1use std::collections::{HashMap, VecDeque};
16use std::fs::{File, OpenOptions};
17use std::io::Write;
18use std::path::PathBuf;
19#[cfg(feature = "simd")]
20use std::sync::atomic::{AtomicBool, Ordering};
21use std::sync::{Arc, RwLock};
22
23use torsh_core::{
24 dtype::TensorElement,
25 error::{Result, TorshError},
26};
27
28#[cfg(feature = "simd")]
30use scirs2_core::simd_aligned::AlignedVec;
31
32#[cfg(unix)]
33use std::os::unix::fs::FileExt;
34#[cfg(windows)]
35use std::os::windows::fs::FileExt;
36
37const MEMORY_MAPPING_THRESHOLD: usize = 1024 * 1024 * 1024;
39
40#[cfg(feature = "simd")]
43const ALIGNED_STORAGE_THRESHOLD: usize = 1024;
44
45#[cfg(feature = "simd")]
48const SIMD_OPTIMIZED_THRESHOLD: usize = 10240;
49
/// SIMD-aligned element buffer with a lightweight copy-on-write marker.
///
/// `shared` is flipped to `true` when the owning `TensorStorage` is
/// cloned (see `TensorStorage::clone`); after that, mutable access via
/// `as_mut_slice_if_unique` is refused.
#[cfg(feature = "simd")]
pub struct SimdStorage<T> {
    // Aligned element buffer (scirs2's AlignedVec).
    data: AlignedVec<T>,
    // Set once the buffer is shared between handles; checked before
    // handing out mutable slices.
    shared: AtomicBool,
}
73
#[cfg(feature = "simd")]
impl<T> SimdStorage<T> {
    /// Wraps an aligned buffer in a fresh, not-yet-shared storage.
    pub fn new(data: AlignedVec<T>) -> Self {
        let shared = AtomicBool::new(false);
        Self { data, shared }
    }

    /// Number of elements in the buffer.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// True when the buffer holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Read-only view of the elements.
    pub fn as_slice(&self) -> &[T] {
        self.data.as_slice()
    }

    /// Allocated capacity of the underlying buffer, in elements.
    pub fn capacity(&self) -> usize {
        self.data.capacity()
    }

    /// Flags the buffer as shared; mutable access is refused afterwards.
    pub fn mark_shared(&self) {
        self.shared.store(true, Ordering::SeqCst);
    }

    /// Whether the buffer has been marked shared.
    pub fn is_shared(&self) -> bool {
        self.shared.load(Ordering::SeqCst)
    }
}
114
#[cfg(feature = "simd")]
impl<T: Copy> SimdStorage<T> {
    /// Mutable view of the buffer, granted only while it is unshared.
    ///
    /// Returns `None` once the storage has been marked shared — the
    /// copy-on-write guard that keeps cloned tensors consistent.
    pub fn as_mut_slice_if_unique(&mut self) -> Option<&mut [T]> {
        if self.is_shared() {
            return None;
        }
        Some(self.data.as_mut_slice())
    }

    /// Copies the contents into a plain `Vec`.
    pub fn to_vec(&self) -> Vec<T> {
        self.as_slice().to_vec()
    }
}
133
#[cfg(feature = "simd")]
impl<T> std::fmt::Debug for SimdStorage<T> {
    /// Renders length and shared flag; element contents are omitted
    /// because `T` is not required to be `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("SimdStorage");
        dbg.field("len", &self.data.len());
        dbg.field("shared", &self.shared.load(Ordering::Relaxed));
        dbg.finish()
    }
}
143
/// Backing storage for tensor element data.
///
/// The variant is normally chosen at construction time by
/// `create_optimal`, based on the data's byte size and enabled features.
pub enum TensorStorage<T: TensorElement> {
    /// Plain heap vector behind a reader/writer lock.
    InMemory(Arc<RwLock<Vec<T>>>),
    /// File-backed storage with an in-process element cache.
    MemoryMapped(Arc<RwLock<MemoryMappedStorage<T>>>),
    /// SIMD-aligned, lock-protected mutable buffer.
    #[cfg(feature = "simd")]
    Aligned(Arc<RwLock<AlignedVec<T>>>),
    /// Lock-free aligned buffer, immutable once shared.
    #[cfg(feature = "simd")]
    SimdOptimized(Arc<SimdStorage<T>>),
}
162
impl<T: TensorElement> std::fmt::Debug for TensorStorage<T> {
    // Forwards to each variant's own Debug; `Aligned` prints a
    // placeholder because AlignedVec has no Debug impl available here.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InMemory(data) => f.debug_tuple("InMemory").field(data).finish(),
            Self::MemoryMapped(storage) => f.debug_tuple("MemoryMapped").field(storage).finish(),
            #[cfg(feature = "simd")]
            Self::Aligned(_) => f.debug_tuple("Aligned").field(&"<AlignedVec>").finish(),
            #[cfg(feature = "simd")]
            Self::SimdOptimized(storage) => f.debug_tuple("SimdOptimized").field(storage).finish(),
        }
    }
}
175
/// File-backed element storage with a bounded in-process cache.
///
/// Despite the name, the current implementation does positioned file
/// I/O per element rather than an OS-level `mmap`; the cache plus an
/// access-history queue approximate LRU eviction.
#[derive(Debug)]
pub struct MemoryMappedStorage<T: TensorElement> {
    // Open handle to the backing file.
    file: File,
    // Path of the backing file (needed for cleanup on drop).
    file_path: PathBuf,
    // Logical element count of the tensor.
    num_elements: usize,
    // Recently accessed elements, keyed by element index.
    cache: HashMap<usize, T>,
    // Upper bound on both the cache and the access-history queue.
    max_cache_size: usize,
    // Access history used for approximate-LRU eviction.
    access_pattern: VecDeque<usize>,
    // When true, the backing file is deleted on drop.
    is_temporary: bool,
}
194
impl<T: TensorElement + Copy> TensorStorage<T> {
    /// Wraps `data` in plain heap-backed storage.
    pub fn in_memory(data: Vec<T>) -> Self {
        Self::InMemory(Arc::new(RwLock::new(data)))
    }

    /// Spills `data` to file-backed storage (a temporary file when
    /// `file_path` is `None`).
    ///
    /// # Errors
    /// `IoError` when the backing file cannot be created or written.
    pub fn memory_mapped(data: Vec<T>, file_path: Option<PathBuf>) -> Result<Self> {
        let storage = MemoryMappedStorage::new(data, file_path)?;
        Ok(Self::MemoryMapped(Arc::new(RwLock::new(storage))))
    }

    /// Copies `data` element-by-element into a SIMD-aligned buffer
    /// behind a lock.
    ///
    /// # Errors
    /// `InvalidArgument` when the aligned buffer cannot be allocated.
    #[cfg(feature = "simd")]
    pub fn aligned(data: Vec<T>) -> Result<Self> {
        let mut aligned_vec = AlignedVec::with_capacity(data.len()).map_err(|e| {
            TorshError::InvalidArgument(format!("Failed to create aligned storage: {e}"))
        })?;

        for item in data {
            aligned_vec.push(item);
        }

        Ok(Self::Aligned(Arc::new(RwLock::new(aligned_vec))))
    }

    /// Fast-path constructor for operation results; currently identical
    /// to `in_memory`.
    pub fn fast_result(data: Vec<T>) -> Self {
        Self::InMemory(Arc::new(RwLock::new(data)))
    }

    /// Copies `data` into lock-free SIMD storage that becomes immutable
    /// once shared (see `SimdStorage`).
    ///
    /// # Errors
    /// `InvalidArgument` when the aligned buffer cannot be allocated.
    #[cfg(feature = "simd")]
    pub fn simd_optimized(data: Vec<T>) -> Result<Self> {
        let mut aligned_vec = AlignedVec::with_capacity(data.len()).map_err(|e| {
            TorshError::InvalidArgument(format!("Failed to create SIMD storage: {e}"))
        })?;

        for item in data {
            aligned_vec.push(item);
        }

        let simd_storage = SimdStorage::new(aligned_vec);
        Ok(Self::SimdOptimized(Arc::new(simd_storage)))
    }

    /// Chooses a storage variant from the data's byte size:
    /// memory-mapped at/above `MEMORY_MAPPING_THRESHOLD`; with the
    /// `simd` feature, SIMD-optimized then aligned at their respective
    /// thresholds; otherwise plain in-memory.
    pub fn create_optimal(data: Vec<T>) -> Result<Self> {
        let size_bytes = data.len() * std::mem::size_of::<T>();

        if size_bytes >= MEMORY_MAPPING_THRESHOLD {
            Self::memory_mapped(data, None)
        } else {
            #[cfg(feature = "simd")]
            {
                if size_bytes >= SIMD_OPTIMIZED_THRESHOLD {
                    return Self::simd_optimized(data);
                } else if size_bytes >= ALIGNED_STORAGE_THRESHOLD {
                    return Self::aligned(data);
                }
            }
            Ok(Self::in_memory(data))
        }
    }

    /// Number of elements; returns 0 when a backing lock is poisoned.
    pub fn len(&self) -> usize {
        match self {
            Self::InMemory(data) => {
                data.read().map(|guard| guard.len()).unwrap_or(0) }
            Self::MemoryMapped(storage) => {
                storage.read().map(|guard| guard.num_elements).unwrap_or(0) }
            #[cfg(feature = "simd")]
            Self::Aligned(data) => {
                data.read().map(|guard| guard.len()).unwrap_or(0) }
            #[cfg(feature = "simd")]
            Self::SimdOptimized(storage) => storage.len(), }
    }

    /// True when the storage holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Reads the element at `index`.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when `index >= len()`; `SynchronizationError`
    /// when a backing lock is poisoned.
    pub fn get(&self, index: usize) -> Result<T>
    where
        T: Copy,
    {
        match self {
            Self::InMemory(data) => {
                let data_guard = data.read().map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during read".to_string())
                })?;
                data_guard
                    .get(index)
                    .copied()
                    .ok_or_else(|| TorshError::IndexOutOfBounds {
                        index,
                        size: data_guard.len(),
                    })
            }
            // Write lock: a read may populate the element cache.
            Self::MemoryMapped(storage) => storage
                .write()
                .map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during write".to_string())
                })?
                .get(index),
            #[cfg(feature = "simd")]
            Self::Aligned(data) => {
                let data_guard = data.read().map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during read".to_string())
                })?;
                if index >= data_guard.len() {
                    Err(TorshError::IndexOutOfBounds {
                        index,
                        size: data_guard.len(),
                    })
                } else {
                    Ok(data_guard.as_slice()[index])
                }
            }
            // Lock-free path: the buffer is immutable, so a plain slice
            // read is safe without synchronization.
            #[cfg(feature = "simd")]
            Self::SimdOptimized(storage) => {
                let slice = storage.as_slice();
                if index >= slice.len() {
                    Err(TorshError::IndexOutOfBounds {
                        index,
                        size: slice.len(),
                    })
                } else {
                    Ok(slice[index])
                }
            }
        }
    }

    /// Writes `value` at `index`.
    ///
    /// # Errors
    /// `IndexOutOfBounds` / `SynchronizationError` as for `get`;
    /// `InvalidArgument` for `SimdOptimized`, which is immutable here.
    pub fn set(&self, index: usize, value: T) -> Result<()>
    where
        T: Copy,
    {
        match self {
            Self::InMemory(data) => {
                let mut data_guard = data.write().map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during write".to_string())
                })?;
                if index >= data_guard.len() {
                    return Err(TorshError::IndexOutOfBounds {
                        index,
                        size: data_guard.len(),
                    });
                }
                data_guard[index] = value;
                Ok(())
            }
            Self::MemoryMapped(storage) => storage
                .write()
                .map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during write".to_string())
                })?
                .set(index, value),
            #[cfg(feature = "simd")]
            Self::Aligned(data) => {
                let mut data_guard = data.write().map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during write".to_string())
                })?;
                if index >= data_guard.len() {
                    return Err(TorshError::IndexOutOfBounds {
                        index,
                        size: data_guard.len(),
                    });
                }
                (*data_guard).set(index, value);
                Ok(())
            }
            #[cfg(feature = "simd")]
            Self::SimdOptimized(_storage) => {
                Err(TorshError::InvalidArgument(
                    "SimdOptimized storage is immutable. Use make_unique() for mutable access."
                        .to_string(),
                ))
            }
        }
    }

    /// Copies `len` elements starting at `start` into a fresh `Vec`.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when the requested range exceeds `len()`;
    /// `SynchronizationError` when a backing lock is poisoned.
    pub fn get_slice(&self, start: usize, len: usize) -> Result<Vec<T>>
    where
        T: Copy,
    {
        match self {
            Self::InMemory(data) => {
                let data_guard = data.read().map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during read".to_string())
                })?;
                if start + len > data_guard.len() {
                    return Err(TorshError::IndexOutOfBounds {
                        index: start + len - 1,
                        size: data_guard.len(),
                    });
                }
                Ok(data_guard[start..start + len].to_vec())
            }
            Self::MemoryMapped(storage) => storage
                .write()
                .map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during write".to_string())
                })?
                .get_slice(start, len),
            #[cfg(feature = "simd")]
            Self::Aligned(data) => {
                let data_guard = data.read().map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during read".to_string())
                })?;
                if start + len > data_guard.len() {
                    return Err(TorshError::IndexOutOfBounds {
                        index: start + len - 1,
                        size: data_guard.len(),
                    });
                }
                let slice = data_guard.as_slice();
                Ok(slice[start..start + len].to_vec())
            }
            #[cfg(feature = "simd")]
            Self::SimdOptimized(storage) => {
                let slice = storage.as_slice();
                if start + len > slice.len() {
                    return Err(TorshError::IndexOutOfBounds {
                        index: start + len - 1,
                        size: slice.len(),
                    });
                }
                Ok(slice[start..start + len].to_vec())
            }
        }
    }

    /// Writes `values` starting at `start`.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when the target range exceeds `len()`;
    /// `InvalidArgument` for the immutable `SimdOptimized` variant.
    pub fn set_slice(&self, start: usize, values: &[T]) -> Result<()>
    where
        T: Copy,
    {
        match self {
            Self::InMemory(data) => {
                let mut data_guard = data.write().map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during write".to_string())
                })?;
                if start + values.len() > data_guard.len() {
                    return Err(TorshError::IndexOutOfBounds {
                        index: start + values.len() - 1,
                        size: data_guard.len(),
                    });
                }
                data_guard[start..start + values.len()].copy_from_slice(values);
                Ok(())
            }
            Self::MemoryMapped(storage) => storage
                .write()
                .map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during write".to_string())
                })?
                .set_slice(start, values),
            #[cfg(feature = "simd")]
            Self::Aligned(data) => {
                let mut data_guard = data.write().map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during write".to_string())
                })?;
                if start + values.len() > data_guard.len() {
                    return Err(TorshError::IndexOutOfBounds {
                        index: start + values.len() - 1,
                        size: data_guard.len(),
                    });
                }
                let slice = data_guard.as_mut_slice();
                slice[start..start + values.len()].copy_from_slice(values);
                Ok(())
            }
            #[cfg(feature = "simd")]
            Self::SimdOptimized(_storage) => {
                Err(TorshError::InvalidArgument(
                    "SimdOptimized storage is immutable. Use make_unique() for mutable access."
                        .to_string(),
                ))
            }
        }
    }

    /// Copies the full contents into a plain `Vec`.
    ///
    /// # Errors
    /// `SynchronizationError` on lock poisoning; I/O errors are
    /// propagated from memory-mapped storage.
    pub fn to_vec(&self) -> Result<Vec<T>>
    where
        T: Copy,
    {
        match self {
            Self::InMemory(data) => Ok(data
                .read()
                .map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during read".to_string())
                })?
                .clone()),
            Self::MemoryMapped(storage) => storage
                .write()
                .map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during write".to_string())
                })?
                .to_vec(),
            #[cfg(feature = "simd")]
            Self::Aligned(data) => {
                let data_guard = data.read().map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during read".to_string())
                })?;
                Ok(data_guard.as_slice().to_vec())
            }
            #[cfg(feature = "simd")]
            Self::SimdOptimized(storage) => {
                Ok(storage.to_vec())
            }
        }
    }

    /// Stable string tag identifying the storage variant (used by tests
    /// and diagnostics).
    pub fn storage_type(&self) -> &'static str {
        match self {
            Self::InMemory(_) => "in_memory",
            Self::MemoryMapped(_) => "memory_mapped",
            #[cfg(feature = "simd")]
            Self::Aligned(_) => "aligned_simd",
            #[cfg(feature = "simd")]
            Self::SimdOptimized(_) => "simd_optimized",
        }
    }

    /// Approximate resident memory in bytes.
    ///
    /// For memory-mapped storage this counts only the cache plus the
    /// bookkeeping struct, not the file contents; aligned variants
    /// report allocated capacity rather than length.
    pub fn memory_usage(&self) -> usize {
        match self {
            Self::InMemory(data) => {
                data.read()
                    .map(|guard| guard.len() * std::mem::size_of::<T>())
                    .unwrap_or(0) }
            Self::MemoryMapped(storage) => {
                storage
                    .read()
                    .map(|storage_guard| {
                        storage_guard.cache.len() * std::mem::size_of::<T>()
                            + std::mem::size_of::<MemoryMappedStorage<T>>()
                    })
                    .unwrap_or(std::mem::size_of::<MemoryMappedStorage<T>>()) }
            #[cfg(feature = "simd")]
            Self::Aligned(data) => {
                data.read()
                    .map(|data_guard| {
                        data_guard.capacity() * std::mem::size_of::<T>()
                    })
                    .unwrap_or(0) }
            #[cfg(feature = "simd")]
            Self::SimdOptimized(storage) => {
                storage.capacity() * std::mem::size_of::<T>()
            }
        }
    }

    /// Runs `f` over a contiguous view of all elements.
    ///
    /// Memory-mapped storage has no contiguous buffer, so its elements
    /// are first copied into a temporary `Vec`.
    pub fn with_slice<R, F>(&self, f: F) -> Result<R>
    where
        F: FnOnce(&[T]) -> Result<R>,
        T: Copy,
    {
        match self {
            Self::InMemory(data) => {
                let data_guard = data.read().map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during read".to_string())
                })?;
                f(data_guard.as_slice())
            }
            Self::MemoryMapped(storage) => {
                let vec = storage
                    .write()
                    .map_err(|_| {
                        TorshError::SynchronizationError("Lock poisoned during write".to_string())
                    })?
                    .to_vec()?;
                f(&vec)
            }
            #[cfg(feature = "simd")]
            Self::Aligned(data) => {
                let data_guard = data.read().map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during read".to_string())
                })?;
                f(data_guard.as_slice())
            }
            #[cfg(feature = "simd")]
            Self::SimdOptimized(storage) => {
                f(storage.as_slice())
            }
        }
    }

    /// Zero-copy, lock-free slice access; only the `SimdOptimized`
    /// variant can provide it.
    #[cfg(feature = "simd")]
    pub fn try_as_slice_direct(&self) -> Option<&[T]> {
        match self {
            Self::SimdOptimized(storage) => Some(storage.as_slice()),
            _ => None,
        }
    }

    /// Runs `f` over a mutable contiguous view of all elements.
    ///
    /// # Errors
    /// `InvalidArgument` for `MemoryMapped` (no contiguous buffer) and
    /// `SimdOptimized` (immutable); `SynchronizationError` on lock
    /// poisoning.
    pub fn with_slice_mut<R, F>(&self, f: F) -> Result<R>
    where
        F: FnOnce(&mut [T]) -> Result<R>,
        T: Copy,
    {
        match self {
            Self::InMemory(data) => {
                let mut data_guard = data.write().map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during write".to_string())
                })?;
                f(data_guard.as_mut_slice())
            }
            Self::MemoryMapped(_) => {
                Err(TorshError::InvalidArgument(
                    "Memory-mapped storage does not support mutable slice access".to_string(),
                ))
            }
            #[cfg(feature = "simd")]
            Self::Aligned(data) => {
                let mut data_guard = data.write().map_err(|_| {
                    TorshError::SynchronizationError("Lock poisoned during write".to_string())
                })?;
                f(data_guard.as_mut_slice())
            }
            #[cfg(feature = "simd")]
            Self::SimdOptimized(_) => {
                Err(TorshError::InvalidArgument(
                    "SimdOptimized storage is immutable. Use make_unique() for mutable access."
                        .to_string(),
                ))
            }
        }
    }
}
737
738impl<T: TensorElement> MemoryMappedStorage<T> {
739 pub fn new(data: Vec<T>, file_path: Option<PathBuf>) -> Result<Self> {
741 let (file_path, is_temporary) = match file_path {
742 Some(path) => (path, false),
743 None => {
744 let temp_dir = std::env::temp_dir();
746 let temp_file = temp_dir.join(format!("torsh_tensor_{}.mmap", std::process::id()));
747 (temp_file, true)
748 }
749 };
750
751 let mut file = OpenOptions::new()
753 .create(true)
754 .read(true)
755 .write(true)
756 .truncate(true)
757 .open(&file_path)
758 .map_err(|e| {
759 TorshError::IoError(format!("Failed to create memory-mapped file: {e}"))
760 })?;
761
762 let data_bytes = unsafe {
764 std::slice::from_raw_parts(
765 data.as_ptr() as *const u8,
766 data.len() * std::mem::size_of::<T>(),
767 )
768 };
769 file.write_all(data_bytes).map_err(|e| {
770 TorshError::IoError(format!("Failed to write to memory-mapped file: {e}"))
771 })?;
772 file.flush()
773 .map_err(|e| TorshError::IoError(format!("Failed to flush memory-mapped file: {e}")))?;
774
775 Ok(Self {
776 file,
777 file_path,
778 num_elements: data.len(),
779 cache: HashMap::new(),
780 max_cache_size: 10000, access_pattern: VecDeque::new(),
782 is_temporary,
783 })
784 }
785
786 pub fn get(&mut self, index: usize) -> Result<T>
788 where
789 T: Copy,
790 {
791 if index >= self.num_elements {
792 return Err(TorshError::IndexOutOfBounds {
793 index,
794 size: self.num_elements,
795 });
796 }
797
798 if let Some(&value) = self.cache.get(&index) {
800 self.update_access_pattern(index);
801 return Ok(value);
802 }
803
804 let value = self.read_element_from_file(index)?;
806
807 if self.cache.len() < self.max_cache_size {
809 self.cache.insert(index, value);
810 } else {
811 self.evict_lru();
813 self.cache.insert(index, value);
814 }
815
816 self.update_access_pattern(index);
817 Ok(value)
818 }
819
820 pub fn set(&mut self, index: usize, value: T) -> Result<()>
822 where
823 T: Copy,
824 {
825 if index >= self.num_elements {
826 return Err(TorshError::IndexOutOfBounds {
827 index,
828 size: self.num_elements,
829 });
830 }
831
832 self.cache.insert(index, value);
834
835 self.write_element_to_file(index, value)?;
837 self.update_access_pattern(index);
838 Ok(())
839 }
840
841 pub fn get_slice(&mut self, start: usize, len: usize) -> Result<Vec<T>>
843 where
844 T: Copy,
845 {
846 if start + len > self.num_elements {
847 return Err(TorshError::IndexOutOfBounds {
848 index: start + len - 1,
849 size: self.num_elements,
850 });
851 }
852
853 let mut result = Vec::with_capacity(len);
854 for i in 0..len {
855 result.push(self.get(start + i)?);
856 }
857 Ok(result)
858 }
859
860 pub fn set_slice(&mut self, start: usize, values: &[T]) -> Result<()>
862 where
863 T: Copy,
864 {
865 if start + values.len() > self.num_elements {
866 return Err(TorshError::IndexOutOfBounds {
867 index: start + values.len() - 1,
868 size: self.num_elements,
869 });
870 }
871
872 for (i, &value) in values.iter().enumerate() {
873 self.set(start + i, value)?;
874 }
875 Ok(())
876 }
877
878 pub fn to_vec(&mut self) -> Result<Vec<T>>
880 where
881 T: Copy,
882 {
883 self.get_slice(0, self.num_elements)
884 }
885
886 fn read_element_from_file(&mut self, index: usize) -> Result<T>
888 where
889 T: Copy,
890 {
891 let offset = index * std::mem::size_of::<T>();
892 let mut buffer = vec![0u8; std::mem::size_of::<T>()];
893
894 #[cfg(unix)]
895 {
896 self.file
897 .read_exact_at(&mut buffer, offset as u64)
898 .map_err(|e| {
899 TorshError::IoError(format!("Failed to read from memory-mapped file: {e}"))
900 })?;
901 }
902
903 #[cfg(windows)]
904 {
905 self.file
906 .seek_read(&mut buffer, offset as u64)
907 .map_err(|e| {
908 TorshError::IoError(format!("Failed to read from memory-mapped file: {e}"))
909 })?;
910 }
911
912 #[cfg(not(any(unix, windows)))]
913 {
914 self.file
915 .seek(SeekFrom::Start(offset as u64))
916 .map_err(|e| {
917 TorshError::IoError(format!("Failed to seek in memory-mapped file: {e}"))
918 })?;
919 self.file.read_exact(&mut buffer).map_err(|e| {
920 TorshError::IoError(format!("Failed to read from memory-mapped file: {e}"))
921 })?;
922 }
923
924 let value = unsafe { std::ptr::read(buffer.as_ptr() as *const T) };
926 Ok(value)
927 }
928
929 fn write_element_to_file(&mut self, index: usize, value: T) -> Result<()>
931 where
932 T: Copy,
933 {
934 let offset = index * std::mem::size_of::<T>();
935 let buffer = unsafe {
936 std::slice::from_raw_parts(&value as *const T as *const u8, std::mem::size_of::<T>())
937 };
938
939 #[cfg(unix)]
940 {
941 self.file.write_all_at(buffer, offset as u64).map_err(|e| {
942 TorshError::IoError(format!("Failed to write to memory-mapped file: {e}"))
943 })?;
944 }
945
946 #[cfg(windows)]
947 {
948 self.file.seek_write(buffer, offset as u64).map_err(|e| {
949 TorshError::IoError(format!("Failed to write to memory-mapped file: {e}"))
950 })?;
951 }
952
953 #[cfg(not(any(unix, windows)))]
954 {
955 self.file
956 .seek(SeekFrom::Start(offset as u64))
957 .map_err(|e| {
958 TorshError::IoError(format!("Failed to seek in memory-mapped file: {e}"))
959 })?;
960 self.file.write_all(buffer).map_err(|e| {
961 TorshError::IoError(format!("Failed to write to memory-mapped file: {e}"))
962 })?;
963 }
964
965 Ok(())
966 }
967
968 fn update_access_pattern(&mut self, index: usize) {
970 self.access_pattern.push_back(index);
971 if self.access_pattern.len() > self.max_cache_size {
972 self.access_pattern.pop_front();
973 }
974 }
975
976 fn evict_lru(&mut self) {
978 if let Some(lru_index) = self.access_pattern.front().copied() {
979 self.cache.remove(&lru_index);
980 }
981 }
982}
983
984impl<T: TensorElement> Drop for MemoryMappedStorage<T> {
985 fn drop(&mut self) {
986 if self.is_temporary {
987 let _ = std::fs::remove_file(&self.file_path);
989 }
990 }
991}
992
impl<T: TensorElement> Clone for TensorStorage<T> {
    /// Shallow clone: every variant shares its backing buffer via `Arc`.
    fn clone(&self) -> Self {
        match self {
            Self::InMemory(data) => Self::InMemory(Arc::clone(data)),
            Self::MemoryMapped(storage) => Self::MemoryMapped(Arc::clone(storage)),
            #[cfg(feature = "simd")]
            Self::Aligned(data) => Self::Aligned(Arc::clone(data)),
            #[cfg(feature = "simd")]
            Self::SimdOptimized(storage) => {
                // Mark the buffer shared so no handle can obtain mutable
                // access afterwards (see SimdStorage::as_mut_slice_if_unique).
                storage.mark_shared();
                Self::SimdOptimized(Arc::clone(storage))
            }
        }
    }
}
1009
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain heap-backed storage: length, type tag, reads and slicing.
    #[test]
    fn test_in_memory_storage() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let storage = TensorStorage::in_memory(data.clone());

        assert_eq!(storage.len(), 4);
        assert!(!storage.is_empty());
        assert_eq!(storage.storage_type(), "in_memory");
        assert_eq!(storage.get(0).expect("get(0) failed"), 1.0);
        assert_eq!(storage.get(3).expect("get(3) failed"), 4.0);

        let middle = storage.get_slice(1, 2).expect("get_slice failed");
        assert_eq!(middle, vec![2.0, 3.0]);
    }

    /// 200 f32s (800 bytes) sit below every threshold, so the optimal
    /// choice is plain in-memory storage with or without `simd`.
    #[test]
    fn test_optimal_storage_selection() {
        let small_data = vec![1.0f32; 200];
        let selected =
            TensorStorage::create_optimal(small_data).expect("create_optimal failed");
        assert_eq!(selected.storage_type(), "in_memory");
    }

    /// Reported usage is element count times element size for in-memory.
    #[test]
    fn test_memory_usage_calculation() {
        let storage = TensorStorage::in_memory(vec![1.0f32; 1000]);
        assert_eq!(storage.memory_usage(), 1000 * std::mem::size_of::<f32>());
    }

    /// Aligned storage mirrors the source data through every accessor.
    #[test]
    #[cfg(feature = "simd")]
    fn test_aligned_storage() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let storage =
            TensorStorage::aligned(data.clone()).expect("aligned storage creation failed");

        assert_eq!(storage.len(), 4);
        assert!(!storage.is_empty());
        assert_eq!(storage.storage_type(), "aligned_simd");
        assert_eq!(storage.get(0).expect("get(0) failed"), 1.0);
        assert_eq!(storage.get(3).expect("get(3) failed"), 4.0);

        let middle = storage.get_slice(1, 2).expect("get_slice failed");
        assert_eq!(middle, vec![2.0, 3.0]);

        let roundtrip = storage.to_vec().expect("to_vec failed");
        assert_eq!(roundtrip, data);
    }

    /// 2000 f32s (8 KB) fall in the aligned band; 100 f32s (400 B)
    /// stay in plain memory.
    #[test]
    #[cfg(feature = "simd")]
    fn test_optimal_storage_selection_with_aligned() {
        let medium = TensorStorage::create_optimal(vec![1.0f32; 2000])
            .expect("create_optimal for medium data failed");
        assert_eq!(medium.storage_type(), "aligned_simd");

        let small = TensorStorage::create_optimal(vec![1.0f32; 100])
            .expect("create_optimal for small data failed");
        assert_eq!(small.storage_type(), "in_memory");
    }
}