rustorch/tensor/memory/optimization.rs

1//! Memory optimization utilities for tensor operations
2//! テンソル操作のためのメモリ最適化ユーティリティ
3
4use super::aligned::{SimdAllocator, SIMD_ALIGNMENT};
5use crate::error::{RusTorchError, RusTorchResult};
6use crate::tensor::Tensor;
7use num_traits::Float;
8
/// Snapshot of how a tensor's element data is sized, laid out and placed.
#[derive(Clone, Debug)]
pub struct TensorMemoryInfo {
    /// Number of elements stored in the tensor.
    pub total_elements: usize,
    /// Size of a single element, in bytes.
    pub element_size: usize,
    /// Total size of the element data, in bytes.
    pub total_bytes: usize,
    /// Whether the element data is reported as contiguous in memory.
    pub is_contiguous: bool,
    /// Alignment of the start of the data buffer, in bytes.
    pub alignment: usize,
    /// Whether the tensor is stored on a GPU.
    pub is_on_gpu: bool,
    /// Name of the device the tensor is stored on.
    pub device: String,
}
35
36/// Memory optimization trait for tensors
37/// テンソルのメモリ最適化トレイト
38pub trait MemoryOptimization<T: Float> {
39    /// Get memory information about this tensor
40    /// このテンソルのメモリ情報を取得
41    fn memory_info(&self) -> TensorMemoryInfo;
42
43    /// Check if this tensor can be optimized for memory usage
44    /// このテンソルがメモリ使用量を最適化できるかチェック
45    fn can_optimize_memory(&self) -> bool;
46
47    /// Create a memory-optimized copy of this tensor
48    /// このテンソルのメモリ最適化コピーを作成
49    fn optimize_memory(&self) -> Self;
50
51    /// Try to create a memory-optimized copy with error handling
52    /// エラーハンドリング付きでメモリ最適化コピーを作成を試行
53    fn try_optimize_memory(&self) -> RusTorchResult<Self>
54    where
55        Self: Sized;
56}
57
58impl<T: Float + Clone + 'static> MemoryOptimization<T> for Tensor<T> {
59    fn memory_info(&self) -> TensorMemoryInfo {
60        let element_size = std::mem::size_of::<T>();
61        let total_elements = self.data.len();
62        let total_bytes = total_elements * element_size;
63
64        // Check for SIMD alignment
65        let ptr = self.data.as_ptr();
66        let alignment = if SimdAllocator::is_aligned(ptr) {
67            SIMD_ALIGNMENT
68        } else {
69            // Check for standard alignments
70            if (ptr as usize) % 16 == 0 {
71                16
72            } else if (ptr as usize) % 8 == 0 {
73                8
74            } else if (ptr as usize) % 4 == 0 {
75                4
76            } else {
77                1
78            }
79        };
80
81        let is_on_gpu = self.is_on_gpu();
82        let device = self.device_type().to_string();
83
84        TensorMemoryInfo {
85            total_elements,
86            element_size,
87            total_bytes,
88            is_contiguous: self.data.is_standard_layout(),
89            alignment,
90            is_on_gpu,
91            device,
92        }
93    }
94
95    fn can_optimize_memory(&self) -> bool {
96        let info = self.memory_info();
97        // Can optimize if tensor is large enough and not already SIMD-aligned
98        info.total_bytes > 1024 && info.alignment < SIMD_ALIGNMENT
99    }
100
101    fn optimize_memory(&self) -> Self {
102        if !self.can_optimize_memory() {
103            return self.clone();
104        }
105
106        // Try to create SIMD-aligned copy for f32 tensors
107        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
108            let shape = self.shape();
109            let len = self.numel();
110
111            // Allocate SIMD-aligned memory
112            if let Ok(ptr) = SimdAllocator::alloc_f32(len) {
113                unsafe {
114                    // Copy data to aligned memory
115                    let src = self.data.as_ptr();
116                    let dst = ptr.as_ptr();
117                    std::ptr::copy_nonoverlapping(src as *const f32, dst, len);
118
119                    // Create vector from aligned memory
120                    let aligned_data = Vec::from_raw_parts(dst, len, len);
121
122                    // Convert to T type (this is safe because we checked TypeId)
123                    let aligned_data_t: Vec<T> = std::mem::transmute(aligned_data);
124
125                    // Create tensor from aligned data
126                    match Self::try_from_vec(aligned_data_t, shape.to_vec()) {
127                        Ok(tensor) => return tensor,
128                        Err(_) => {
129                            // If creation fails, deallocate and fall back to clone
130                            SimdAllocator::dealloc_f32(ptr, len);
131                        }
132                    }
133                }
134            }
135        }
136
137        // Fallback to regular clone if SIMD optimization fails
138        self.clone()
139    }
140
141    fn try_optimize_memory(&self) -> RusTorchResult<Self> {
142        let info = self.memory_info();
143
144        // Check if tensor is too large to optimize safely
145        const MAX_OPTIMIZE_SIZE: usize = 1_000_000_000; // 1GB
146        if info.total_bytes > MAX_OPTIMIZE_SIZE {
147            return Err(RusTorchError::TensorOp {
148                message: format!(
149                    "Tensor too large to optimize: {} bytes exceeds maximum of {} bytes",
150                    info.total_bytes, MAX_OPTIMIZE_SIZE
151                ),
152                source: None,
153            });
154        }
155
156        Ok(self.optimize_memory())
157    }
158}