Skip to main content

torsh_data/
zero_copy.rs

1//! Zero-copy data loading utilities for efficient memory management
2//!
3//! This module provides zero-copy tensor operations and memory management
4//! utilities for high-performance data processing pipelines. It enables
5//! efficient handling of large datasets without unnecessary memory allocations.
6//!
7//! # Features
8//!
9//! - **Zero-copy tensors**: Work directly with existing memory without copying
10//! - **Memory pools**: Reuse allocated tensors to reduce allocation overhead
11//! - **Buffer management**: Efficient buffer reuse in data pipelines
12//! - **Memory mapping**: Direct access to file data without loading into memory
13//! - **Thread-safe operations**: Concurrent access to shared memory pools
14
15use parking_lot::Mutex;
16use std::{mem, slice};
17use torsh_core::error::{Result, TorshError};
18
19/// Zero-copy tensor wrapper that avoids unnecessary memory allocation
20///
21/// This struct provides a view into existing memory without copying data.
22/// It can work with borrowed slices or take ownership of allocated memory.
pub struct ZeroCopyTensor<T> {
    // Base address of the first element; may borrow foreign memory or point
    // into an allocation this tensor owns (see `owned`).
    data_ptr: *const T,
    // Size of each dimension.
    shape: Vec<usize>,
    // Elements to skip per step along each dimension (row-major by default).
    stride: Vec<usize>,
    // Total element count (product of `shape`); also used as the length/capacity
    // when `Drop` reconstructs an owned allocation.
    capacity: usize,
    // When true, `Drop` frees the backing allocation; when false the memory
    // belongs to someone else and must outlive this tensor.
    owned: bool,
}
30
31impl<T> ZeroCopyTensor<T> {
32    /// Create a zero-copy tensor from existing data without copying
33    ///
34    /// # Safety
35    ///
36    /// This function is unsafe because it directly uses raw pointers. The caller must ensure:
37    /// - `data_ptr` is a valid pointer to a memory region that contains at least `capacity` elements
38    /// - The memory region remains valid for the lifetime of the ZeroCopyTensor
39    /// - The memory is properly aligned for type T
40    /// - The shape and stride parameters correctly describe the tensor layout
41    /// - No other code mutates the memory region while this tensor exists
42    pub unsafe fn from_raw_parts(
43        data_ptr: *const T,
44        shape: Vec<usize>,
45        stride: Vec<usize>,
46    ) -> Self {
47        let capacity = shape.iter().product();
48        Self {
49            data_ptr,
50            shape,
51            stride,
52            capacity,
53            owned: false,
54        }
55    }
56
57    /// Create a zero-copy tensor from a slice
58    ///
59    /// This creates a view into the provided slice without copying data.
60    /// The slice must remain valid for the lifetime of the tensor.
61    pub fn from_slice(data: &[T], shape: Vec<usize>) -> Self {
62        let capacity = shape.iter().product();
63        assert_eq!(
64            data.len(),
65            capacity,
66            "Data length must match tensor capacity"
67        );
68
69        let stride = Self::compute_stride(&shape);
70        Self {
71            data_ptr: data.as_ptr(),
72            shape,
73            stride,
74            capacity,
75            owned: false,
76        }
77    }
78
79    /// Create a zero-copy tensor by taking ownership of a Vec
80    ///
81    /// This transfers ownership of the Vec's memory to the tensor,
82    /// avoiding the need to copy data.
83    pub fn from_vec(data: Vec<T>, shape: Vec<usize>) -> Self {
84        let capacity = shape.iter().product();
85        assert_eq!(
86            data.len(),
87            capacity,
88            "Data length must match tensor capacity"
89        );
90
91        let stride = Self::compute_stride(&shape);
92        let data_ptr = data.as_ptr();
93        mem::forget(data); // Transfer ownership to the tensor
94
95        Self {
96            data_ptr,
97            shape,
98            stride,
99            capacity,
100            owned: true,
101        }
102    }
103
104    /// Get the shape of the tensor
105    pub fn shape(&self) -> &[usize] {
106        &self.shape
107    }
108
109    /// Get the stride of the tensor
110    pub fn stride(&self) -> &[usize] {
111        &self.stride
112    }
113
114    /// Get the total number of elements
115    pub fn len(&self) -> usize {
116        self.capacity
117    }
118
119    /// Check if the tensor is empty
120    pub fn is_empty(&self) -> bool {
121        self.capacity == 0
122    }
123
124    /// Get data as a slice
125    ///
126    /// # Safety
127    /// This is safe as long as the tensor was constructed properly and
128    /// the underlying memory remains valid.
129    pub fn as_slice(&self) -> &[T] {
130        unsafe { slice::from_raw_parts(self.data_ptr, self.capacity) }
131    }
132
133    /// Compute stride from shape (row-major order)
134    ///
135    /// Stride indicates how many elements to skip when moving along each dimension.
136    fn compute_stride(shape: &[usize]) -> Vec<usize> {
137        let mut stride = vec![1; shape.len()];
138        for i in (0..shape.len().saturating_sub(1)).rev() {
139            stride[i] = stride[i + 1] * shape[i + 1];
140        }
141        stride
142    }
143
144    /// Create a view into a subregion without copying data
145    ///
146    /// This creates a new tensor that views a slice of the current tensor.
147    /// The data is not copied, only the view parameters are adjusted.
148    pub fn slice_view(&self, ranges: &[(usize, usize)]) -> Result<ZeroCopyTensor<T>> {
149        if ranges.len() != self.shape.len() {
150            return Err(TorshError::InvalidArgument(
151                "Number of slice ranges must match tensor dimensions".to_string(),
152            ));
153        }
154
155        let mut new_shape = Vec::new();
156        let mut offset = 0;
157
158        for (i, &(start, end)) in ranges.iter().enumerate() {
159            if start >= end || end > self.shape[i] {
160                return Err(TorshError::InvalidArgument(
161                    "Invalid slice range".to_string(),
162                ));
163            }
164            new_shape.push(end - start);
165            offset += start * self.stride[i];
166        }
167
168        let new_stride = self.stride.clone();
169        let new_data_ptr = unsafe { self.data_ptr.add(offset) };
170        let capacity = new_shape.iter().product();
171
172        Ok(ZeroCopyTensor {
173            data_ptr: new_data_ptr,
174            shape: new_shape,
175            stride: new_stride,
176            capacity,
177            owned: false,
178        })
179    }
180
181    /// Get the number of dimensions
182    pub fn ndim(&self) -> usize {
183        self.shape.len()
184    }
185
186    /// Check if this tensor owns its memory
187    pub fn is_owned(&self) -> bool {
188        self.owned
189    }
190}
191
// Safety: ZeroCopyTensor can be sent between threads if T is Send
// NOTE(review): a non-owned tensor behaves like `&[T]`, whose `Send` impl
// requires `T: Sync`; the `T: Send` bound alone may be too permissive for
// borrowed tensors — confirm intended usage.
unsafe impl<T: Send> Send for ZeroCopyTensor<T> {}

// Safety: ZeroCopyTensor can be shared between threads if T is Sync
// (the API exposes only `&self` read access to the element data).
unsafe impl<T: Sync> Sync for ZeroCopyTensor<T> {}
197
impl<T> Drop for ZeroCopyTensor<T> {
    /// Frees the backing allocation when (and only when) this tensor owns it.
    fn drop(&mut self) {
        if self.owned {
            unsafe {
                // SAFETY: `owned == true` only for tensors built from an owned
                // allocation. Rebuilding a `Vec` with len == cap == `capacity`
                // and letting it drop runs element destructors and frees the
                // buffer.
                // NOTE(review): this is sound only if the original allocation's
                // capacity equals `self.capacity`; a `Vec` with spare capacity
                // (cap > len) passed to `from_vec` would be deallocated with the
                // wrong layout (UB) — verify `from_vec` normalizes its input.
                // Convert back to Vec and let it handle deallocation
                let _vec =
                    Vec::from_raw_parts(self.data_ptr as *mut T, self.capacity, self.capacity);
            }
        }
    }
}
209
210/// Memory pool for reusing allocated tensors to avoid allocation/deallocation overhead
211///
212/// This pool maintains a collection of pre-allocated vectors that can be reused
213/// to avoid the overhead of memory allocation and deallocation in tight loops.
pub struct TensorPool<T> {
    // Idle vectors available for reuse; guarded for concurrent access.
    pool: Mutex<Vec<Vec<T>>>,
    // Upper bound on how many idle vectors are retained; extras are dropped
    // when returned.
    max_size: usize,
}
218
219impl<T: Clone + Default> TensorPool<T> {
220    /// Create a new tensor pool
221    ///
222    /// # Arguments
223    /// * `max_size` - Maximum number of tensors to keep in the pool
224    pub fn new(max_size: usize) -> Self {
225        Self {
226            pool: Mutex::new(Vec::new()),
227            max_size,
228        }
229    }
230
231    /// Get a tensor from the pool or allocate a new one
232    ///
233    /// If a suitable tensor is available in the pool, it will be reused.
234    /// Otherwise, a new tensor will be allocated.
235    pub fn get(&self, capacity: usize) -> Vec<T> {
236        let mut pool = self.pool.lock();
237
238        // Look for a tensor with sufficient capacity
239        for i in 0..pool.len() {
240            if pool[i].capacity() >= capacity {
241                let mut tensor = pool.swap_remove(i);
242                tensor.clear();
243                tensor.resize(capacity, T::default());
244                return tensor;
245            }
246        }
247
248        // No suitable tensor found, allocate a new one
249        vec![T::default(); capacity]
250    }
251
252    /// Return a tensor to the pool
253    ///
254    /// The tensor will be stored in the pool for future reuse if there's space.
255    pub fn return_tensor(&self, tensor: Vec<T>) {
256        let mut pool = self.pool.lock();
257        if pool.len() < self.max_size {
258            pool.push(tensor);
259        }
260        // If pool is full, the tensor will be dropped and deallocated
261    }
262
263    /// Get the number of tensors currently in the pool
264    pub fn pool_size(&self) -> usize {
265        self.pool.lock().len()
266    }
267
268    /// Clear all tensors from the pool
269    pub fn clear(&self) {
270        self.pool.lock().clear();
271    }
272}
273
274/// Memory-mapped data loader for large datasets
275///
276/// This provides a way to access file data directly without loading it entirely into memory.
277/// Useful for working with datasets larger than available RAM.
pub struct MemoryMappedLoader {
    // Path to the backing file; validated to exist at construction time.
    // No actual memory map is held yet — see `load_slice`/`can_map`.
    file_path: std::path::PathBuf,
}
281
282impl MemoryMappedLoader {
283    /// Create a new memory-mapped loader
284    ///
285    /// # Arguments
286    /// * `file_path` - Path to the file to be memory-mapped
287    pub fn new<P: AsRef<std::path::Path>>(file_path: P) -> Result<Self> {
288        let file_path = file_path.as_ref().to_path_buf();
289
290        // Verify file exists
291        if !file_path.exists() {
292            return Err(TorshError::InvalidArgument(format!(
293                "File does not exist: {}",
294                file_path.display()
295            )));
296        }
297
298        Ok(Self { file_path })
299    }
300
301    /// Get the file path
302    pub fn file_path(&self) -> &std::path::Path {
303        &self.file_path
304    }
305
306    /// Get file size in bytes
307    pub fn file_size(&self) -> Result<u64> {
308        std::fs::metadata(&self.file_path)
309            .map(|metadata| metadata.len())
310            .map_err(|e| TorshError::InvalidArgument(format!("Failed to get file size: {}", e)))
311    }
312
313    /// Load data without copying (placeholder implementation)
314    ///
315    /// In a full implementation with memmap2 dependency, this would return
316    /// a slice directly from the memory-mapped file.
317    pub fn load_slice(&self, _offset: usize, _length: usize) -> Result<&[u8]> {
318        // In a real implementation, this would return a slice directly from the memory map
319        // using something like: &self.mmap[offset..offset + length]
320        Err(TorshError::UnsupportedOperation {
321            op: "memory mapping".to_string(),
322            dtype: "MemoryMappedLoader".to_string(),
323        })
324    }
325
326    /// Check if the file can be memory-mapped
327    pub fn can_map(&self) -> bool {
328        // For now, always return false since we don't have memmap2
329        // In a real implementation, this would check file accessibility
330        false
331    }
332}
333
334/// Buffer manager for efficient buffer reuse in data pipelines
335///
336/// This manages a pool of pre-allocated buffers that can be acquired and released
337/// by data processing operations to avoid repeated allocation/deallocation.
pub struct BufferManager<T> {
    // Buffers currently not checked out; guarded for concurrent access.
    available_buffers: Mutex<Vec<Vec<T>>>,
    // Total number of buffers the manager maintains (pool is pre-filled to
    // this count at construction).
    max_buffers: usize,
    // Length, in elements, of each managed buffer.
    buffer_size: usize,
}
343
344impl<T: Clone + Default> BufferManager<T> {
345    /// Create a new buffer manager
346    ///
347    /// # Arguments
348    /// * `max_buffers` - Maximum number of buffers to maintain
349    /// * `buffer_size` - Size of each buffer in elements
350    pub fn new(max_buffers: usize, buffer_size: usize) -> Self {
351        let mut available_buffers = Vec::with_capacity(max_buffers);
352        for _ in 0..max_buffers {
353            available_buffers.push(vec![T::default(); buffer_size]);
354        }
355
356        Self {
357            available_buffers: Mutex::new(available_buffers),
358            max_buffers,
359            buffer_size,
360        }
361    }
362
363    /// Acquire a buffer from the pool
364    ///
365    /// Returns `Some(buffer)` if a buffer is available, `None` if all buffers are in use.
366    pub fn acquire_buffer(&self) -> Option<Vec<T>> {
367        let mut available = self.available_buffers.lock();
368        available.pop()
369    }
370
371    /// Release a buffer back to the pool
372    ///
373    /// The buffer will be returned to the pool if there's space, otherwise it will be dropped.
374    pub fn release_buffer(&self, buffer: Vec<T>) {
375        let mut available = self.available_buffers.lock();
376        if available.len() < self.max_buffers {
377            available.push(buffer);
378        }
379    }
380
381    /// Get number of available buffers
382    pub fn available_count(&self) -> usize {
383        self.available_buffers.lock().len()
384    }
385
386    /// Get number of buffers in use
387    pub fn in_use_count(&self) -> usize {
388        self.max_buffers - self.available_count()
389    }
390
391    /// Get the configured buffer size
392    pub fn buffer_size(&self) -> usize {
393        self.buffer_size
394    }
395
396    /// Get the maximum number of buffers
397    pub fn max_buffers(&self) -> usize {
398        self.max_buffers
399    }
400
401    /// Reset all buffers (clear and return to pool)
402    pub fn reset(&self) {
403        let mut available = self.available_buffers.lock();
404        available.clear();
405        for _ in 0..self.max_buffers {
406            available.push(vec![T::default(); self.buffer_size]);
407        }
408    }
409}
410
411/// Convenience function to create a zero-copy tensor from a vector
412pub fn zero_copy_from_vec<T>(data: Vec<T>, shape: Vec<usize>) -> ZeroCopyTensor<T> {
413    ZeroCopyTensor::from_vec(data, shape)
414}
415
416/// Convenience function to create a zero-copy tensor from a slice
417pub fn zero_copy_from_slice<T>(data: &[T], shape: Vec<usize>) -> ZeroCopyTensor<T> {
418    ZeroCopyTensor::from_slice(data, shape)
419}
420
421/// Convenience function to create a tensor pool
422pub fn create_tensor_pool<T: Clone + Default>(max_size: usize) -> TensorPool<T> {
423    TensorPool::new(max_size)
424}
425
426/// Convenience function to create a buffer manager
427pub fn create_buffer_manager<T: Clone + Default>(
428    max_buffers: usize,
429    buffer_size: usize,
430) -> BufferManager<T> {
431    BufferManager::new(max_buffers, buffer_size)
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    // Owning construction: metadata accessors reflect shape/ownership.
    #[test]
    fn test_zero_copy_tensor_from_vec() {
        let data = vec![1, 2, 3, 4, 5, 6];
        let shape = vec![2, 3];
        let tensor = ZeroCopyTensor::from_vec(data, shape.clone());

        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.len(), 6);
        assert!(!tensor.is_empty());
        assert!(tensor.is_owned());
        assert_eq!(tensor.ndim(), 2);
    }

    // Borrowing construction: tensor views the slice and does not own it.
    // NOTE(review): `tensor` must not outlive `data` — the API does not
    // enforce this with a lifetime, the test just happens to respect it.
    #[test]
    fn test_zero_copy_tensor_from_slice() {
        let data = vec![1, 2, 3, 4];
        let shape = vec![2, 2];
        let tensor = ZeroCopyTensor::from_slice(&data, shape.clone());

        assert_eq!(tensor.shape(), &[2, 2]);
        assert_eq!(tensor.len(), 4);
        assert!(!tensor.is_owned());
        assert_eq!(tensor.as_slice(), &[1, 2, 3, 4]);
    }

    // Sub-view shares memory: only shape/len/ownership are checked here
    // (the view's data is non-contiguous, so `as_slice` is not asserted).
    #[test]
    fn test_zero_copy_tensor_slice_view() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let shape = vec![3, 3];
        let tensor = ZeroCopyTensor::from_vec(data, shape);

        // Create a 2x2 slice starting at (1,1)
        let ranges = vec![(1, 3), (1, 3)];
        let slice_view = tensor.slice_view(&ranges).unwrap();

        assert_eq!(slice_view.shape(), &[2, 2]);
        assert_eq!(slice_view.len(), 4);
        assert!(!slice_view.is_owned());
    }

    // Pool round-trip: get -> return -> get again reuses the allocation;
    // assertions depend on this exact ordering.
    #[test]
    fn test_tensor_pool() {
        let pool = TensorPool::<f32>::new(3);
        assert_eq!(pool.pool_size(), 0);

        // Get a tensor
        let tensor1 = pool.get(10);
        assert_eq!(tensor1.len(), 10);

        // Return it to the pool
        pool.return_tensor(tensor1);
        assert_eq!(pool.pool_size(), 1);

        // Get it back (should be reused)
        let tensor2 = pool.get(10);
        assert_eq!(tensor2.len(), 10);
        assert_eq!(pool.pool_size(), 0);

        pool.return_tensor(tensor2);
        pool.clear();
        assert_eq!(pool.pool_size(), 0);
    }

    // Acquire/release bookkeeping, exhaustion behavior, and reset refill.
    #[test]
    fn test_buffer_manager() {
        let manager = BufferManager::<u8>::new(2, 100);
        assert_eq!(manager.available_count(), 2);
        assert_eq!(manager.in_use_count(), 0);
        assert_eq!(manager.buffer_size(), 100);
        assert_eq!(manager.max_buffers(), 2);

        // Acquire buffers
        let buffer1 = manager.acquire_buffer().unwrap();
        assert_eq!(buffer1.len(), 100);
        assert_eq!(manager.available_count(), 1);

        let buffer2 = manager.acquire_buffer().unwrap();
        assert_eq!(manager.available_count(), 0);

        // No more buffers available
        assert!(manager.acquire_buffer().is_none());

        // Release buffers
        manager.release_buffer(buffer1);
        assert_eq!(manager.available_count(), 1);

        manager.release_buffer(buffer2);
        assert_eq!(manager.available_count(), 2);

        // Test reset
        manager.reset();
        assert_eq!(manager.available_count(), 2);
    }

    // Constructor rejects missing paths; placeholder load always errors.
    #[test]
    fn test_memory_mapped_loader() {
        // Test with a non-existent file
        let result = MemoryMappedLoader::new("/non/existent/file");
        assert!(result.is_err());

        // Test loading slice (should fail with unsupported operation)
        // (guarded: /dev/null exists only on Unix-like systems)
        if let Ok(loader) = MemoryMappedLoader::new("/dev/null") {
            let result = loader.load_slice(0, 10);
            assert!(result.is_err());
            assert!(!loader.can_map());
        }
    }

    // Row-major strides: each entry is the product of dimensions to its right.
    #[test]
    fn test_stride_computation() {
        // Test 2D tensor stride
        let stride = ZeroCopyTensor::<f32>::compute_stride(&[3, 4]);
        assert_eq!(stride, vec![4, 1]);

        // Test 3D tensor stride
        let stride = ZeroCopyTensor::<f32>::compute_stride(&[2, 3, 4]);
        assert_eq!(stride, vec![12, 4, 1]);

        // Test 1D tensor stride
        let stride = ZeroCopyTensor::<f32>::compute_stride(&[5]);
        assert_eq!(stride, vec![1]);
    }

    // Smoke test: the free-function wrappers construct without panicking.
    #[test]
    fn test_convenience_functions() {
        let data = vec![1, 2, 3, 4];
        let shape = vec![2, 2];

        let _tensor_from_vec = zero_copy_from_vec(data.clone(), shape.clone());
        let _tensor_from_slice = zero_copy_from_slice(&data, shape);
        let _pool = create_tensor_pool::<f32>(10);
        let _manager = create_buffer_manager::<u8>(5, 100);
    }

    // Shape/length mismatch must panic with the documented message.
    #[test]
    #[should_panic(expected = "Data length must match tensor capacity")]
    fn test_shape_mismatch() {
        let data = vec![1, 2, 3];
        let shape = vec![2, 2]; // Requires 4 elements, but data has 3
        ZeroCopyTensor::from_vec(data, shape);
    }
}