use parking_lot::Mutex;
use std::slice;
use torsh_core::error::{Result, TorshError};

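/// A tensor that references its data through a raw pointer, either
/// borrowing memory owned elsewhere (`owned == false`) or owning a heap
/// allocation that is freed on drop.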
pub struct ZeroCopyTensor<T> {
    data_ptr: *const T,
    shape: Vec<usize>,
    stride: Vec<usize>,
    capacity: usize,
    owned: bool,
}

impl<T> ZeroCopyTensor<T> {
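    /// Builds a non-owning tensor view from a raw pointer plus explicit
    /// shape and stride.
    ///
    /// # Safety
    ///
    /// `data_ptr` must be valid for reads of `shape.iter().product()`
    /// elements for as long as the returned tensor is alive.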
    pub unsafe fn from_raw_parts(
        data_ptr: *const T,
        shape: Vec<usize>,
        stride: Vec<usize>,
    ) -> Self {
        let capacity = shape.iter().product();
        Self {
            data_ptr,
            shape,
            stride,
            capacity,
            owned: false,
        }
    }

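    /// Creates a borrowed, non-owning view over `data`.
    ///
    /// The view stores a raw pointer rather than a lifetime-tracked
    /// reference, so the caller must keep `data` alive for as long as
    /// the tensor is used.
    ///
    /// Panics if `data.len()` does not equal the product of `shape`.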
    pub fn from_slice(data: &[T], shape: Vec<usize>) -> Self {
        let capacity = shape.iter().product();
        assert_eq!(
            data.len(),
            capacity,
            "Data length must match tensor capacity"
        );

        let stride = Self::compute_stride(&shape);
        Self {
            data_ptr: data.as_ptr(),
            shape,
            stride,
            capacity,
            owned: false,
        }
    }

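    /// Creates an owning tensor from `data`; the allocation is released
    /// when the tensor is dropped.
    ///
    /// Panics if `data.len()` does not equal the product of `shape`.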
    pub fn from_vec(data: Vec<T>, shape: Vec<usize>) -> Self {
        let capacity = shape.iter().product();
        assert_eq!(
            data.len(),
            capacity,
            "Data length must match tensor capacity"
        );

        let stride = Self::compute_stride(&shape);
        // Convert to a boxed slice first so the allocation's capacity is
        // exactly `capacity`; `Drop` relies on this when it rebuilds the
        // `Vec` to free the memory. Forgetting the `Vec` directly would be
        // unsound, since its capacity may exceed its length.
        let data_ptr = Box::into_raw(data.into_boxed_slice()) as *const T;
        Self {
            data_ptr,
            shape,
            stride,
            capacity,
            owned: true,
        }
    }

    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    pub fn stride(&self) -> &[usize] {
        &self.stride
    }

    pub fn len(&self) -> usize {
        self.capacity
    }

    pub fn is_empty(&self) -> bool {
        self.capacity == 0
    }

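    /// Reinterprets the underlying memory as a flat, contiguous slice of
    /// `len()` elements.
    ///
    /// Views returned by `slice_view` inherit the parent's strides, so
    /// for those this flat slice does not correspond to the view's
    /// logical element order.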
    pub fn as_slice(&self) -> &[T] {
        unsafe { slice::from_raw_parts(self.data_ptr, self.capacity) }
    }

    fn compute_stride(shape: &[usize]) -> Vec<usize> {
        // Row-major (C-order) strides: the last dimension is contiguous.
        let mut stride = vec![1; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            stride[i] = stride[i + 1] * shape[i + 1];
        }
        stride
    }

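    /// Returns a borrowed sub-view selected by one half-open
    /// `(start, end)` range per dimension. Strides are inherited from
    /// `self`, so the resulting view may be non-contiguous.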
    pub fn slice_view(&self, ranges: &[(usize, usize)]) -> Result<ZeroCopyTensor<T>> {
        if ranges.len() != self.shape.len() {
            return Err(TorshError::InvalidArgument(
                "Number of slice ranges must match tensor dimensions".to_string(),
            ));
        }

        let mut new_shape = Vec::new();
        let mut offset = 0;

        for (i, &(start, end)) in ranges.iter().enumerate() {
            if start >= end || end > self.shape[i] {
                return Err(TorshError::InvalidArgument(
                    "Invalid slice range".to_string(),
                ));
            }
            new_shape.push(end - start);
            offset += start * self.stride[i];
        }

        let new_stride = self.stride.clone();
        let new_data_ptr = unsafe { self.data_ptr.add(offset) };
        let capacity = new_shape.iter().product();

        Ok(ZeroCopyTensor {
            data_ptr: new_data_ptr,
            shape: new_shape,
            stride: new_stride,
            capacity,
            owned: false,
        })
    }

    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    pub fn is_owned(&self) -> bool {
        self.owned
    }
}

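// SAFETY: the tensor is a pointer plus plain metadata; moving it to
// another thread is safe whenever the element type is `Send`.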
unsafe impl<T: Send> Send for ZeroCopyTensor<T> {}

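// SAFETY: shared references only permit reads through the pointer, which
// is safe to share across threads when `T` is `Sync`.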
unsafe impl<T: Sync> Sync for ZeroCopyTensor<T> {}

impl<T> Drop for ZeroCopyTensor<T> {
    fn drop(&mut self) {
        if self.owned {
            // `from_vec` stored a pointer obtained from a boxed slice, so
            // length and capacity are both exactly `self.capacity` and
            // rebuilding the `Vec` here to free the allocation is sound.
            unsafe {
                let _vec =
                    Vec::from_raw_parts(self.data_ptr as *mut T, self.capacity, self.capacity);
            }
        }
    }
}

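/// A pool that recycles `Vec<T>` allocations so repeated tensor creation
/// can reuse existing capacity instead of reallocating.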
pub struct TensorPool<T> {
    pool: Mutex<Vec<Vec<T>>>,
    max_size: usize,
}

impl<T: Clone + Default> TensorPool<T> {
    pub fn new(max_size: usize) -> Self {
        Self {
            pool: Mutex::new(Vec::new()),
            max_size,
        }
    }

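    /// Returns a buffer of exactly `capacity` elements, reusing a pooled
    /// allocation when one is large enough; recycled buffers are cleared
    /// and refilled with `T::default()`.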
    pub fn get(&self, capacity: usize) -> Vec<T> {
        let mut pool = self.pool.lock();

        // Reuse the first pooled buffer with sufficient capacity.
        if let Some(i) = pool.iter().position(|t| t.capacity() >= capacity) {
            let mut tensor = pool.swap_remove(i);
            tensor.clear();
            tensor.resize(capacity, T::default());
            return tensor;
        }

        vec![T::default(); capacity]
    }

    pub fn return_tensor(&self, tensor: Vec<T>) {
        let mut pool = self.pool.lock();
        if pool.len() < self.max_size {
            pool.push(tensor);
        }
    }

    pub fn pool_size(&self) -> usize {
        self.pool.lock().len()
    }

    pub fn clear(&self) {
        self.pool.lock().clear();
    }
}

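/// A placeholder for memory-mapped tensor loading.
///
/// Actual memory mapping is not implemented yet: the loader validates
/// the path and exposes file metadata, while `load_slice` always returns
/// an `UnsupportedOperation` error.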
pub struct MemoryMappedLoader {
    file_path: std::path::PathBuf,
}

impl MemoryMappedLoader {
    pub fn new<P: AsRef<std::path::Path>>(file_path: P) -> Result<Self> {
        let file_path = file_path.as_ref().to_path_buf();

        if !file_path.exists() {
            return Err(TorshError::InvalidArgument(format!(
                "File does not exist: {}",
                file_path.display()
            )));
        }

        Ok(Self { file_path })
    }

    pub fn file_path(&self) -> &std::path::Path {
        &self.file_path
    }

    pub fn file_size(&self) -> Result<u64> {
        std::fs::metadata(&self.file_path)
            .map(|metadata| metadata.len())
            .map_err(|e| TorshError::InvalidArgument(format!("Failed to get file size: {}", e)))
    }

    pub fn load_slice(&self, _offset: usize, _length: usize) -> Result<&[u8]> {
        // Memory mapping is not implemented yet.
        Err(TorshError::UnsupportedOperation {
            op: "memory mapping".to_string(),
            dtype: "MemoryMappedLoader".to_string(),
        })
    }

    pub fn can_map(&self) -> bool {
        false
    }
}

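/// A fixed set of pre-allocated, equally sized buffers that can be
/// acquired and released, capping peak memory use.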
pub struct BufferManager<T> {
    available_buffers: Mutex<Vec<Vec<T>>>,
    max_buffers: usize,
    buffer_size: usize,
}

impl<T: Clone + Default> BufferManager<T> {
    pub fn new(max_buffers: usize, buffer_size: usize) -> Self {
        let mut available_buffers = Vec::with_capacity(max_buffers);
        for _ in 0..max_buffers {
            available_buffers.push(vec![T::default(); buffer_size]);
        }

        Self {
            available_buffers: Mutex::new(available_buffers),
            max_buffers,
            buffer_size,
        }
    }

    pub fn acquire_buffer(&self) -> Option<Vec<T>> {
        let mut available = self.available_buffers.lock();
        available.pop()
    }

    pub fn release_buffer(&self, buffer: Vec<T>) {
        let mut available = self.available_buffers.lock();
        if available.len() < self.max_buffers {
            available.push(buffer);
        }
    }

    pub fn available_count(&self) -> usize {
        self.available_buffers.lock().len()
    }

    pub fn in_use_count(&self) -> usize {
        self.max_buffers - self.available_count()
    }

    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }

    pub fn max_buffers(&self) -> usize {
        self.max_buffers
    }

    pub fn reset(&self) {
        let mut available = self.available_buffers.lock();
        available.clear();
        for _ in 0..self.max_buffers {
            available.push(vec![T::default(); self.buffer_size]);
        }
    }
}

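/// Convenience wrapper around [`ZeroCopyTensor::from_vec`].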
pub fn zero_copy_from_vec<T>(data: Vec<T>, shape: Vec<usize>) -> ZeroCopyTensor<T> {
    ZeroCopyTensor::from_vec(data, shape)
}

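/// Convenience wrapper around [`ZeroCopyTensor::from_slice`].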
pub fn zero_copy_from_slice<T>(data: &[T], shape: Vec<usize>) -> ZeroCopyTensor<T> {
    ZeroCopyTensor::from_slice(data, shape)
}

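/// Convenience constructor for a [`TensorPool`].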
pub fn create_tensor_pool<T: Clone + Default>(max_size: usize) -> TensorPool<T> {
    TensorPool::new(max_size)
}

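/// Convenience constructor for a [`BufferManager`].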
pub fn create_buffer_manager<T: Clone + Default>(
    max_buffers: usize,
    buffer_size: usize,
) -> BufferManager<T> {
    BufferManager::new(max_buffers, buffer_size)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zero_copy_tensor_from_vec() {
        let data = vec![1, 2, 3, 4, 5, 6];
        let shape = vec![2, 3];
        let tensor = ZeroCopyTensor::from_vec(data, shape);

        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.len(), 6);
        assert!(!tensor.is_empty());
        assert!(tensor.is_owned());
        assert_eq!(tensor.ndim(), 2);
    }

    #[test]
    fn test_zero_copy_tensor_from_slice() {
        let data = vec![1, 2, 3, 4];
        let shape = vec![2, 2];
        let tensor = ZeroCopyTensor::from_slice(&data, shape);

        assert_eq!(tensor.shape(), &[2, 2]);
        assert_eq!(tensor.len(), 4);
        assert!(!tensor.is_owned());
        assert_eq!(tensor.as_slice(), &[1, 2, 3, 4]);
    }

    #[test]
    fn test_zero_copy_tensor_slice_view() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let shape = vec![3, 3];
        let tensor = ZeroCopyTensor::from_vec(data, shape);

        // Select the bottom-right 2x2 block of the 3x3 matrix.
        let ranges = vec![(1, 3), (1, 3)];
        let slice_view = tensor.slice_view(&ranges).unwrap();

        assert_eq!(slice_view.shape(), &[2, 2]);
        assert_eq!(slice_view.len(), 4);
        assert!(!slice_view.is_owned());
    }

    #[test]
    fn test_tensor_pool() {
        let pool = TensorPool::<f32>::new(3);
        assert_eq!(pool.pool_size(), 0);

        let tensor1 = pool.get(10);
        assert_eq!(tensor1.len(), 10);

        pool.return_tensor(tensor1);
        assert_eq!(pool.pool_size(), 1);

        let tensor2 = pool.get(10);
        assert_eq!(tensor2.len(), 10);
        assert_eq!(pool.pool_size(), 0);

        pool.return_tensor(tensor2);
        pool.clear();
        assert_eq!(pool.pool_size(), 0);
    }

    #[test]
    fn test_buffer_manager() {
        let manager = BufferManager::<u8>::new(2, 100);
        assert_eq!(manager.available_count(), 2);
        assert_eq!(manager.in_use_count(), 0);
        assert_eq!(manager.buffer_size(), 100);
        assert_eq!(manager.max_buffers(), 2);

        let buffer1 = manager.acquire_buffer().unwrap();
        assert_eq!(buffer1.len(), 100);
        assert_eq!(manager.available_count(), 1);

        let buffer2 = manager.acquire_buffer().unwrap();
        assert_eq!(manager.available_count(), 0);

        assert!(manager.acquire_buffer().is_none());

        manager.release_buffer(buffer1);
        assert_eq!(manager.available_count(), 1);

        manager.release_buffer(buffer2);
        assert_eq!(manager.available_count(), 2);

        manager.reset();
        assert_eq!(manager.available_count(), 2);
    }

    #[test]
    fn test_memory_mapped_loader() {
        let result = MemoryMappedLoader::new("/non/existent/file");
        assert!(result.is_err());

        if let Ok(loader) = MemoryMappedLoader::new("/dev/null") {
            let result = loader.load_slice(0, 10);
            assert!(result.is_err());
            assert!(!loader.can_map());
        }
    }

    #[test]
    fn test_stride_computation() {
        let stride = ZeroCopyTensor::<f32>::compute_stride(&[3, 4]);
        assert_eq!(stride, vec![4, 1]);

        let stride = ZeroCopyTensor::<f32>::compute_stride(&[2, 3, 4]);
        assert_eq!(stride, vec![12, 4, 1]);

        let stride = ZeroCopyTensor::<f32>::compute_stride(&[5]);
        assert_eq!(stride, vec![1]);
    }

    #[test]
    fn test_convenience_functions() {
        let data = vec![1, 2, 3, 4];
        let shape = vec![2, 2];

        let _tensor_from_vec = zero_copy_from_vec(data.clone(), shape.clone());
        let _tensor_from_slice = zero_copy_from_slice(&data, shape);
        let _pool = create_tensor_pool::<f32>(10);
        let _manager = create_buffer_manager::<u8>(5, 100);
    }

    #[test]
    #[should_panic(expected = "Data length must match tensor capacity")]
    fn test_shape_mismatch() {
        let data = vec![1, 2, 3];
        let shape = vec![2, 2];
        ZeroCopyTensor::from_vec(data, shape);
    }
}