rustorch/tensor/memory/
optimization.rs1use super::aligned::{SimdAllocator, SIMD_ALIGNMENT};
5use crate::error::{RusTorchError, RusTorchResult};
6use crate::tensor::Tensor;
7use num_traits::Float;
8
/// Snapshot of how a tensor's element buffer is laid out in memory.
///
/// Values are plain data; two snapshots compare equal when every field matches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorMemoryInfo {
    /// Number of elements stored in the tensor.
    pub total_elements: usize,
    /// Size of a single element in bytes (`size_of::<T>()`).
    pub element_size: usize,
    /// Total payload size in bytes (`total_elements * element_size`).
    pub total_bytes: usize,
    /// Whether the underlying storage reports a standard (contiguous) layout.
    pub is_contiguous: bool,
    /// Detected alignment of the data pointer, in bytes.
    pub alignment: usize,
    /// Whether the tensor reports that it lives on a GPU.
    pub is_on_gpu: bool,
    /// Textual form of the tensor's device type.
    pub device: String,
}
35
36pub trait MemoryOptimization<T: Float> {
39 fn memory_info(&self) -> TensorMemoryInfo;
42
43 fn can_optimize_memory(&self) -> bool;
46
47 fn optimize_memory(&self) -> Self;
50
51 fn try_optimize_memory(&self) -> RusTorchResult<Self>
54 where
55 Self: Sized;
56}
57
58impl<T: Float + Clone + 'static> MemoryOptimization<T> for Tensor<T> {
59 fn memory_info(&self) -> TensorMemoryInfo {
60 let element_size = std::mem::size_of::<T>();
61 let total_elements = self.data.len();
62 let total_bytes = total_elements * element_size;
63
64 let ptr = self.data.as_ptr();
66 let alignment = if SimdAllocator::is_aligned(ptr) {
67 SIMD_ALIGNMENT
68 } else {
69 if (ptr as usize) % 16 == 0 {
71 16
72 } else if (ptr as usize) % 8 == 0 {
73 8
74 } else if (ptr as usize) % 4 == 0 {
75 4
76 } else {
77 1
78 }
79 };
80
81 let is_on_gpu = self.is_on_gpu();
82 let device = self.device_type().to_string();
83
84 TensorMemoryInfo {
85 total_elements,
86 element_size,
87 total_bytes,
88 is_contiguous: self.data.is_standard_layout(),
89 alignment,
90 is_on_gpu,
91 device,
92 }
93 }
94
95 fn can_optimize_memory(&self) -> bool {
96 let info = self.memory_info();
97 info.total_bytes > 1024 && info.alignment < SIMD_ALIGNMENT
99 }
100
101 fn optimize_memory(&self) -> Self {
102 if !self.can_optimize_memory() {
103 return self.clone();
104 }
105
106 if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
108 let shape = self.shape();
109 let len = self.numel();
110
111 if let Ok(ptr) = SimdAllocator::alloc_f32(len) {
113 unsafe {
114 let src = self.data.as_ptr();
116 let dst = ptr.as_ptr();
117 std::ptr::copy_nonoverlapping(src as *const f32, dst, len);
118
119 let aligned_data = Vec::from_raw_parts(dst, len, len);
121
122 let aligned_data_t: Vec<T> = std::mem::transmute(aligned_data);
124
125 match Self::try_from_vec(aligned_data_t, shape.to_vec()) {
127 Ok(tensor) => return tensor,
128 Err(_) => {
129 SimdAllocator::dealloc_f32(ptr, len);
131 }
132 }
133 }
134 }
135 }
136
137 self.clone()
139 }
140
141 fn try_optimize_memory(&self) -> RusTorchResult<Self> {
142 let info = self.memory_info();
143
144 const MAX_OPTIMIZE_SIZE: usize = 1_000_000_000; if info.total_bytes > MAX_OPTIMIZE_SIZE {
147 return Err(RusTorchError::TensorOp {
148 message: format!(
149 "Tensor too large to optimize: {} bytes exceeds maximum of {} bytes",
150 info.total_bytes, MAX_OPTIMIZE_SIZE
151 ),
152 source: None,
153 });
154 }
155
156 Ok(self.optimize_memory())
157 }
158}