1use crate::errors::{Result, TrustformersError};
2use crate::tensor::Tensor;
3use scirs2_core::ndarray::{s, IxDyn};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs::File;
7use std::io::{Read, Seek, SeekFrom};
8use std::sync::{Arc, Mutex, RwLock};
9use std::time::{Duration, Instant};
10
/// Policy used to rank pooled tensors when the pool must evict entries.
///
/// `PoolEntry::eviction_priority` turns each policy into a score; entries are
/// sorted by that score and evicted from the low end during cleanup.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryEvictionPolicy {
    /// Rank by recency of use (least recently used goes first).
    LRU,
    /// Rank by access frequency.
    LFU,
    /// Rank by entry size in bytes.
    SizeBased,
    /// Adaptive replacement: combines a recency score and a frequency score.
    ARC,
    /// Weighted mix of recency, frequency, and size.
    Hybrid,
}
33
/// Strategy for adjusting the pool's dynamic size limit at runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AdaptiveStrategy {
    /// Keep the configured limit unchanged.
    Fixed,
    /// Shrink when pool utilization is high, grow when it is low.
    MemoryPressure,
    /// Grow while the cache hit rate is below `MemoryConfig::target_hit_rate`,
    /// shrink when comfortably above it.
    HitRate,
    /// Grow or shrink based on the volume of recent shape accesses.
    Predictive,
}
46
/// Tuning knobs for the tensor memory pool and related memory features.
#[derive(Debug, Clone)]
pub struct MemoryConfig {
    /// Master switch: when false, tensors are always freshly allocated and
    /// returned tensors are dropped instead of pooled.
    pub enable_memory_pool: bool,
    /// Hard upper bound (bytes) the adaptive limit may grow to.
    pub max_pool_size: usize,
    /// Lower bound (bytes) the adaptive limit may shrink to.
    pub min_pool_size: usize,
    /// Enables zero-copy slicing support (not read within this module — presumably
    /// consulted by callers; TODO confirm).
    pub enable_zero_copy: bool,
    /// Enables memory-mapped tensor loading (not read within this module).
    pub enable_mmap: bool,
    /// Size threshold (bytes) above which mmap is preferred (not read within this module).
    pub mmap_threshold: usize,
    /// Minimum elapsed time between time-triggered cleanup passes.
    pub cleanup_interval: Duration,
    /// How pooled tensors are ranked for eviction.
    pub eviction_policy: MemoryEvictionPolicy,
    /// How the dynamic pool-size limit is adjusted.
    pub adaptive_strategy: AdaptiveStrategy,
    /// Desired cache hit rate targeted by `AdaptiveStrategy::HitRate`.
    pub target_hit_rate: f64,
    /// Records per-shape access timestamps so frequent shapes can be predicted.
    pub enable_prefetching: bool,
    /// Re-sorts pool entries after cleanup (see `defragment_pool`).
    pub enable_defragmentation: bool,
}
75
76impl Default for MemoryConfig {
77 fn default() -> Self {
78 Self {
79 enable_memory_pool: true,
80 max_pool_size: 1024 * 1024 * 1024, min_pool_size: 64 * 1024 * 1024, enable_zero_copy: true,
83 enable_mmap: true,
84 mmap_threshold: 100 * 1024 * 1024, cleanup_interval: Duration::from_secs(60),
86 eviction_policy: MemoryEvictionPolicy::Hybrid,
87 adaptive_strategy: AdaptiveStrategy::HitRate,
88 target_hit_rate: 0.85, enable_prefetching: true,
90 enable_defragmentation: true,
91 }
92 }
93}
94
/// A pooled tensor plus the bookkeeping used for eviction decisions.
#[derive(Debug, Clone)]
struct PoolEntry {
    /// The cached tensor itself.
    tensor: Tensor,
    /// Timestamp of the most recent checkout (updated by `mark_accessed`).
    last_used: Instant,
    /// Entries with a non-zero ref_count are skipped by eviction.
    /// NOTE(review): never incremented anywhere in this file — confirm intent.
    ref_count: usize,
    /// Number of times this entry was handed back out of the pool.
    access_count: usize,
    #[allow(dead_code)]
    /// When the entry was first created.
    created_at: Instant,
    #[allow(dead_code)]
    /// Reserved: cumulative time spent in the pool (never updated).
    pool_time: Duration,
    /// Estimated payload size in bytes (see `estimate_tensor_size`).
    size_bytes: usize,
}
112
113impl PoolEntry {
114 fn new(tensor: Tensor, size_bytes: usize) -> Self {
115 let now = Instant::now();
116 Self {
117 tensor,
118 last_used: now,
119 ref_count: 0,
120 access_count: 0,
121 created_at: now,
122 pool_time: Duration::ZERO,
123 size_bytes,
124 }
125 }
126
127 fn mark_accessed(&mut self) {
128 self.last_used = Instant::now();
129 self.access_count += 1;
130 }
131
132 fn eviction_priority(&self, policy: MemoryEvictionPolicy) -> f64 {
134 match policy {
135 MemoryEvictionPolicy::LRU => {
136 -(self.last_used.elapsed().as_secs_f64())
138 },
139 MemoryEvictionPolicy::LFU => {
140 -(self.access_count as f64)
142 },
143 MemoryEvictionPolicy::SizeBased => {
144 -(self.size_bytes as f64)
146 },
147 MemoryEvictionPolicy::ARC => {
148 let recency_score = 1.0 / (1.0 + self.last_used.elapsed().as_secs_f64());
150 let frequency_score = self.access_count as f64;
151 -(recency_score + frequency_score)
152 },
153 MemoryEvictionPolicy::Hybrid => {
154 let recency = 1.0 / (1.0 + self.last_used.elapsed().as_secs_f64());
156 let frequency = self.access_count as f64;
157 let size_factor = 1.0 / (1.0 + (self.size_bytes as f64 / 1_000_000.0));
158 -(recency * 0.4 + frequency * 0.4 + size_factor * 0.2)
159 },
160 }
161 }
162}
163
/// A lightweight 1-D view into the flattened elements of a shared `Tensor`.
///
/// The view itself holds no tensor data; `as_tensor` materializes (copies)
/// the selected region.
#[derive(Debug)]
pub struct TensorView {
    /// Shared handle to the tensor being viewed.
    original: Arc<Tensor>,
    /// Flat element offset where the view begins.
    offset: usize,
    /// Shape of the view — a single dimension for slices.
    shape: Vec<usize>,
    #[allow(dead_code)]
    /// Element strides; always `[1]` for slices, kept for future use.
    strides: Vec<usize>,
}
177
178impl TensorView {
179 pub fn slice(tensor: Arc<Tensor>, start: usize, end: usize) -> Result<Self> {
181 let original_shape = tensor.shape();
182 if start >= end || end > original_shape.iter().product::<usize>() {
183 return Err(TrustformersError::invalid_input(
184 "Invalid slice bounds".to_string(),
185 ));
186 }
187
188 let slice_len = end - start;
189 Ok(Self {
190 original: tensor,
191 offset: start,
192 shape: vec![slice_len],
193 strides: vec![1],
194 })
195 }
196
197 pub fn shape(&self) -> &[usize] {
199 &self.shape
200 }
201
202 pub fn as_tensor(&self) -> Result<Tensor> {
204 match &*self.original {
207 Tensor::F32(arr) => {
208 let flat = arr
209 .view()
210 .into_shape_with_order(arr.len())
211 .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
212 let slice = flat.slice(s![
213 self.offset..self.offset + self.shape.iter().product::<usize>()
214 ]);
215 let sliced_arr = slice
216 .to_owned()
217 .into_shape_with_order(IxDyn(&self.shape))
218 .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
219 Ok(Tensor::F32(sliced_arr))
220 },
221 _ => Err(TrustformersError::tensor_op_error(
222 "Zero-copy slicing not implemented for this tensor type",
223 "zero_copy_slice",
224 )),
225 }
226 }
227}
228
/// Internal counters describing pool behavior since the last reset.
#[derive(Debug, Clone)]
struct PoolStatistics {
    /// Total `get_tensor` calls observed.
    total_requests: usize,
    /// Requests satisfied by a pooled tensor.
    cache_hits: usize,
    /// Requests that required a fresh allocation.
    cache_misses: usize,
    /// Entries removed by cleanup, across all policies.
    total_evictions: usize,
    /// Eviction counts keyed by `format!("{:?}", policy)`.
    evictions_by_policy: HashMap<String, usize>,
    /// Cumulative bytes accepted back into the pool.
    total_allocated_bytes: usize,
    /// High-water mark of the pool's `current_size`.
    peak_memory_usage: usize,
    #[allow(dead_code)]
    /// Reserved: average tensor lifetime (never updated).
    average_tensor_lifetime: Duration,
    #[allow(dead_code)]
    /// Reserved: timestamp of construction / last reset (never read).
    last_reset: Instant,
}
244
245impl Default for PoolStatistics {
246 fn default() -> Self {
247 Self {
248 total_requests: 0,
249 cache_hits: 0,
250 cache_misses: 0,
251 total_evictions: 0,
252 evictions_by_policy: HashMap::new(),
253 total_allocated_bytes: 0,
254 peak_memory_usage: 0,
255 average_tensor_lifetime: Duration::ZERO,
256 last_reset: Instant::now(),
257 }
258 }
259}
260
261impl PoolStatistics {
262 fn hit_rate(&self) -> f64 {
263 if self.total_requests == 0 {
264 0.0
265 } else {
266 self.cache_hits as f64 / self.total_requests as f64
267 }
268 }
269
270 fn miss_rate(&self) -> f64 {
271 if self.total_requests == 0 {
272 0.0
273 } else {
274 self.cache_misses as f64 / self.total_requests as f64
275 }
276 }
277}
278
/// A shape-keyed pool of reusable tensors with adaptive sizing and eviction.
pub struct TensorMemoryPool {
    /// Static configuration supplied at construction.
    config: MemoryConfig,
    /// Pooled entries, grouped by exact tensor shape.
    pool: Arc<RwLock<HashMap<Vec<usize>, Vec<PoolEntry>>>>,
    /// Estimated bytes currently held by the pool.
    current_size: Arc<Mutex<usize>>,
    /// When `cleanup_if_needed` last completed a pass.
    last_cleanup: Arc<Mutex<Instant>>,
    /// Hit/miss/eviction counters.
    statistics: Arc<Mutex<PoolStatistics>>,
    /// Access timestamps per shape; populated when prefetching is enabled.
    access_patterns: Arc<Mutex<HashMap<Vec<usize>, Vec<Instant>>>>,
    /// Adaptive size limit, kept within `min_pool_size..=max_pool_size`.
    dynamic_max_size: Arc<Mutex<usize>>,
}
292
293impl TensorMemoryPool {
294 pub fn new(config: MemoryConfig) -> Self {
296 let dynamic_max_size = config.max_pool_size;
297 Self {
298 config,
299 pool: Arc::new(RwLock::new(HashMap::new())),
300 current_size: Arc::new(Mutex::new(0)),
301 last_cleanup: Arc::new(Mutex::new(Instant::now())),
302 statistics: Arc::new(Mutex::new(PoolStatistics::default())),
303 access_patterns: Arc::new(Mutex::new(HashMap::new())),
304 dynamic_max_size: Arc::new(Mutex::new(dynamic_max_size)),
305 }
306 }
307
308 pub fn get_tensor(&self, shape: &[usize], dtype: crate::tensor::DType) -> Result<Tensor> {
310 if self.config.enable_prefetching {
312 let mut patterns = self.access_patterns.lock().expect("lock should not be poisoned");
313 patterns.entry(shape.to_vec()).or_default().push(Instant::now());
314 }
315
316 {
318 let mut stats = self.statistics.lock().expect("lock should not be poisoned");
319 stats.total_requests += 1;
320 }
321
322 if !self.config.enable_memory_pool {
323 return self.create_tensor(shape, dtype);
324 }
325
326 if let Some(tensor) = self.try_get_from_pool(shape)? {
328 let mut stats = self.statistics.lock().expect("lock should not be poisoned");
330 stats.cache_hits += 1;
331 return Ok(tensor);
332 }
333
334 {
336 let mut stats = self.statistics.lock().expect("lock should not be poisoned");
337 stats.cache_misses += 1;
338 }
339
340 self.apply_adaptive_sizing()?;
342
343 self.create_tensor(shape, dtype)
345 }
346
347 pub fn return_tensor(&self, tensor: Tensor) -> Result<()> {
349 if !self.config.enable_memory_pool {
350 return Ok(()); }
352
353 let shape = tensor.shape().to_vec();
354
355 let tensor_size = self.estimate_tensor_size(&tensor);
357
358 let entry = PoolEntry::new(tensor, tensor_size);
360
361 let mut pool = self.pool.write().expect("lock should not be poisoned");
362 pool.entry(shape).or_default().push(entry);
363
364 {
366 let mut current = self.current_size.lock().expect("lock should not be poisoned");
367 *current += tensor_size;
368
369 let mut stats = self.statistics.lock().expect("lock should not be poisoned");
370 if *current > stats.peak_memory_usage {
371 stats.peak_memory_usage = *current;
372 }
373 stats.total_allocated_bytes += tensor_size;
374 }
375
376 self.cleanup_if_needed()?;
378
379 Ok(())
380 }
381
382 fn try_get_from_pool(&self, shape: &[usize]) -> Result<Option<Tensor>> {
384 let mut pool = self.pool.write().expect("lock should not be poisoned");
385
386 if let Some(entries) = pool.get_mut(shape) {
387 if let Some(mut entry) = entries.pop() {
388 entry.mark_accessed();
390
391 let tensor_size = entry.size_bytes;
392 *self.current_size.lock().expect("lock should not be poisoned") -= tensor_size;
393 return Ok(Some(entry.tensor));
394 }
395 }
396
397 Ok(None)
398 }
399
400 fn create_tensor(&self, shape: &[usize], dtype: crate::tensor::DType) -> Result<Tensor> {
402 match dtype {
403 crate::tensor::DType::F32 => Tensor::zeros(shape),
404 crate::tensor::DType::F64 => Tensor::zeros_f64(shape),
405 crate::tensor::DType::F16 => Tensor::zeros_f16(shape),
406 crate::tensor::DType::BF16 => Tensor::zeros_bf16(shape),
407 crate::tensor::DType::I64 => Tensor::zeros_i64(shape),
408 crate::tensor::DType::C32 => Tensor::zeros_c32(shape),
409 crate::tensor::DType::C64 => Tensor::zeros_c64(shape),
410 crate::tensor::DType::CF16 => Tensor::zeros_cf16(shape),
411 crate::tensor::DType::CBF16 => Tensor::zeros_cbf16(shape),
412 _ => Err(TrustformersError::tensor_op_error(
413 &format!("Tensor creation not implemented for dtype: {:?} - only supported types are F32, F64, F16, BF16, I64, C32, C64, CF16, CBF16", dtype),
414 "create_tensor"
415 )),
416 }
417 }
418
419 fn estimate_tensor_size(&self, tensor: &Tensor) -> usize {
421 let elements = tensor.shape().iter().product::<usize>();
422 match tensor {
423 Tensor::F32(_) => elements * 4, Tensor::F64(_) => elements * 8, Tensor::F16(_) => elements * 2, Tensor::BF16(_) => elements * 2, Tensor::I64(_) => elements * 8, Tensor::C32(_) => elements * 8, Tensor::C64(_) => elements * 16, Tensor::CF16(_) => elements * 4, Tensor::CBF16(_) => elements * 4, #[cfg(feature = "torch")]
433 Tensor::Torch(_) => elements * 4, #[cfg(feature = "candle")]
435 Tensor::Candle(_) => elements * 4, #[cfg(all(target_os = "macos", feature = "metal"))]
437 Tensor::Metal(data) => elements * data.dtype.size_in_bytes(),
438 #[cfg(feature = "cuda")]
439 Tensor::CUDA(data) => elements * data.dtype.size_in_bytes(),
440 Tensor::Sparse(sparse) => {
441 let nnz = sparse.nnz();
443 nnz * 4 + nnz * std::mem::size_of::<usize>() },
445 }
446 }
447
448 fn cleanup_if_needed(&self) -> Result<()> {
450 let mut last_cleanup = self.last_cleanup.lock().expect("lock should not be poisoned");
451 let should_cleanup_time = last_cleanup.elapsed() >= self.config.cleanup_interval;
452
453 let current_size = *self.current_size.lock().expect("lock should not be poisoned");
454 let dynamic_max = *self.dynamic_max_size.lock().expect("lock should not be poisoned");
455 let should_cleanup_size = current_size > dynamic_max;
456
457 if !should_cleanup_time && !should_cleanup_size {
458 return Ok(());
459 }
460
461 let mut pool = self.pool.write().expect("lock should not be poisoned");
463 let mut total_freed = 0;
464 let mut eviction_count = 0;
465 let policy = self.config.eviction_policy;
466
467 let target_size = (dynamic_max as f64 * 0.85) as usize; let need_to_free = current_size.saturating_sub(target_size);
470
471 let mut all_entries: Vec<(Vec<usize>, usize, f64)> = Vec::new();
473
474 for (shape, entries) in pool.iter() {
475 for (idx, entry) in entries.iter().enumerate() {
476 if entry.ref_count == 0 {
477 let priority = entry.eviction_priority(policy);
478 all_entries.push((shape.clone(), idx, priority));
479 }
480 }
481 }
482
483 all_entries.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
485
486 let mut freed_so_far = 0;
488 let mut shapes_to_remove: Vec<Vec<usize>> = Vec::new();
489
490 for (shape, _, _) in all_entries.iter() {
491 if freed_so_far >= need_to_free {
492 break;
493 }
494
495 if let Some(entries) = pool.get_mut(shape) {
496 if let Some(entry) = entries.first() {
497 if entry.ref_count == 0 {
498 let size = entry.size_bytes;
499 freed_so_far += size;
500 total_freed += size;
501 eviction_count += 1;
502 shapes_to_remove.push(shape.clone());
503 }
504 }
505 }
506 }
507
508 for shape in shapes_to_remove {
510 if let Some(entries) = pool.get_mut(&shape) {
511 if !entries.is_empty() {
512 entries.remove(0);
513 }
514 }
515 }
516
517 pool.retain(|_, entries| !entries.is_empty());
519
520 drop(pool); {
524 let mut stats = self.statistics.lock().expect("lock should not be poisoned");
525 stats.total_evictions += eviction_count;
526 *stats.evictions_by_policy.entry(format!("{:?}", policy)).or_insert(0) +=
527 eviction_count;
528 }
529
530 *self.current_size.lock().expect("lock should not be poisoned") -= total_freed;
532 *last_cleanup = Instant::now();
533
534 if self.config.enable_defragmentation {
536 self.defragment_pool()?;
537 }
538
539 Ok(())
540 }
541
542 fn apply_adaptive_sizing(&self) -> Result<()> {
544 match self.config.adaptive_strategy {
545 AdaptiveStrategy::Fixed => Ok(()), AdaptiveStrategy::HitRate => self.adapt_by_hit_rate(),
547 AdaptiveStrategy::MemoryPressure => self.adapt_by_memory_pressure(),
548 AdaptiveStrategy::Predictive => self.adapt_by_prediction(),
549 }
550 }
551
552 fn adapt_by_hit_rate(&self) -> Result<()> {
554 let stats = self.statistics.lock().expect("lock should not be poisoned");
555 let hit_rate = stats.hit_rate();
556 drop(stats);
557
558 let mut dynamic_max = self.dynamic_max_size.lock().expect("lock should not be poisoned");
559 let target_rate = self.config.target_hit_rate;
560
561 if hit_rate < target_rate {
562 let increase = (*dynamic_max as f64 * 0.1) as usize;
564 let new_size = (*dynamic_max + increase).min(self.config.max_pool_size);
565 if new_size > *dynamic_max {
566 *dynamic_max = new_size;
567 }
568 } else if hit_rate > target_rate + 0.1 {
569 let decrease = (*dynamic_max as f64 * 0.05) as usize;
571 let new_size = (*dynamic_max - decrease).max(self.config.min_pool_size);
572 if new_size < *dynamic_max {
573 *dynamic_max = new_size;
574 }
575 }
576
577 Ok(())
578 }
579
580 fn adapt_by_memory_pressure(&self) -> Result<()> {
582 let current_size = *self.current_size.lock().expect("lock should not be poisoned");
585 let mut dynamic_max = self.dynamic_max_size.lock().expect("lock should not be poisoned");
586
587 let utilization = current_size as f64 / *dynamic_max as f64;
588
589 if utilization > 0.9 {
590 let new_size = (*dynamic_max as f64 * 0.9) as usize;
592 *dynamic_max = new_size.max(self.config.min_pool_size);
593 } else if utilization < 0.5 {
594 let new_size = (*dynamic_max as f64 * 1.1) as usize;
596 *dynamic_max = new_size.min(self.config.max_pool_size);
597 }
598
599 Ok(())
600 }
601
602 fn adapt_by_prediction(&self) -> Result<()> {
604 let patterns = self.access_patterns.lock().expect("lock should not be poisoned");
605
606 let mut total_recent_accesses = 0;
608 let recent_window = Duration::from_secs(60);
609 let now = Instant::now();
610
611 for timestamps in patterns.values() {
612 total_recent_accesses +=
613 timestamps.iter().filter(|t| now.duration_since(**t) < recent_window).count();
614 }
615
616 drop(patterns);
617
618 let mut dynamic_max = self.dynamic_max_size.lock().expect("lock should not be poisoned");
620
621 if total_recent_accesses > 1000 {
622 let new_size = (*dynamic_max as f64 * 1.15) as usize;
624 *dynamic_max = new_size.min(self.config.max_pool_size);
625 } else if total_recent_accesses < 100 {
626 let new_size = (*dynamic_max as f64 * 0.9) as usize;
628 *dynamic_max = new_size.max(self.config.min_pool_size);
629 }
630
631 Ok(())
632 }
633
634 fn defragment_pool(&self) -> Result<()> {
636 let mut pool = self.pool.write().expect("lock should not be poisoned");
638
639 for entries in pool.values_mut() {
640 entries.sort_by_key(|entry| std::cmp::Reverse(entry.access_count));
642 }
643
644 Ok(())
645 }
646
647 pub fn get_stats(&self) -> MemoryPoolStats {
649 let pool = self.pool.read().expect("lock should not be poisoned");
650 let current_size = *self.current_size.lock().expect("lock should not be poisoned");
651 let stats = self.statistics.lock().expect("lock should not be poisoned");
652 let dynamic_max = *self.dynamic_max_size.lock().expect("lock should not be poisoned");
653
654 let total_tensors = pool.values().map(|v| v.len()).sum();
655 let total_shapes = pool.len();
656
657 MemoryPoolStats {
658 total_tensors,
659 total_shapes,
660 current_size_bytes: current_size,
661 max_size_bytes: self.config.max_pool_size,
662 dynamic_max_size_bytes: dynamic_max,
663 utilization: current_size as f64 / dynamic_max as f64,
664 hit_rate: stats.hit_rate(),
665 miss_rate: stats.miss_rate(),
666 total_requests: stats.total_requests,
667 cache_hits: stats.cache_hits,
668 cache_misses: stats.cache_misses,
669 total_evictions: stats.total_evictions,
670 peak_memory_usage_bytes: stats.peak_memory_usage,
671 eviction_policy: self.config.eviction_policy,
672 adaptive_strategy: self.config.adaptive_strategy,
673 }
674 }
675
676 pub fn reset_statistics(&self) {
678 let mut stats = self.statistics.lock().expect("Lock poisoned");
679 *stats = PoolStatistics::default();
680 }
681
682 pub fn hit_rate(&self) -> f64 {
684 let stats = self.statistics.lock().expect("lock should not be poisoned");
685 stats.hit_rate()
686 }
687
688 pub fn eviction_policy(&self) -> MemoryEvictionPolicy {
690 self.config.eviction_policy
691 }
692
693 pub fn adaptive_strategy(&self) -> AdaptiveStrategy {
695 self.config.adaptive_strategy
696 }
697
698 pub fn get_predicted_shapes(&self, window: Duration) -> Vec<Vec<usize>> {
700 let patterns = self.access_patterns.lock().expect("lock should not be poisoned");
701 let now = Instant::now();
702
703 let mut frequent_shapes: Vec<(Vec<usize>, usize)> = patterns
704 .iter()
705 .map(|(shape, timestamps)| {
706 let count = timestamps.iter().filter(|t| now.duration_since(**t) < window).count();
707 (shape.clone(), count)
708 })
709 .filter(|(_, count)| *count > 0)
710 .collect();
711
712 frequent_shapes.sort_by_key(|item| std::cmp::Reverse(item.1));
713 frequent_shapes.into_iter().map(|(shape, _)| shape).collect()
714 }
715}
716
/// Public snapshot of the pool's state, produced by `TensorMemoryPool::get_stats`.
#[derive(Debug, Clone)]
pub struct MemoryPoolStats {
    /// Number of tensors currently held in the pool.
    pub total_tensors: usize,
    /// Number of distinct shapes with at least one pooled tensor.
    pub total_shapes: usize,
    /// Estimated bytes currently held.
    pub current_size_bytes: usize,
    /// Configured hard upper bound in bytes.
    pub max_size_bytes: usize,
    /// Current adaptive size limit in bytes.
    pub dynamic_max_size_bytes: usize,
    /// `current_size_bytes / dynamic_max_size_bytes`.
    pub utilization: f64,
    /// Fraction of requests served from the pool.
    pub hit_rate: f64,
    /// Fraction of requests that required fresh allocation.
    pub miss_rate: f64,
    /// Total `get_tensor` calls observed.
    pub total_requests: usize,
    /// Requests satisfied from the pool.
    pub cache_hits: usize,
    /// Requests that missed the pool.
    pub cache_misses: usize,
    /// Entries removed by cleanup passes.
    pub total_evictions: usize,
    /// High-water mark of pool memory usage.
    pub peak_memory_usage_bytes: usize,
    /// Eviction policy in effect.
    pub eviction_policy: MemoryEvictionPolicy,
    /// Adaptive-sizing strategy in effect.
    pub adaptive_strategy: AdaptiveStrategy,
}
751
/// Descriptor for a tensor stored on disk as little-endian values.
///
/// NOTE(review): despite the name, no `mmap(2)` is performed — `load()` reads
/// the entire file into a heap buffer and decodes it.
pub struct MemoryMappedTensor {
    /// Path to the backing file.
    file_path: String,
    /// Logical tensor shape; its element count must match the file size.
    shape: Vec<usize>,
    /// Element type used to decode the bytes.
    dtype: crate::tensor::DType,
    /// Handle retained from validation; not used after construction.
    _file: Option<File>,
    /// File size in bytes, captured at construction.
    file_size: u64,
}
765
766impl MemoryMappedTensor {
767 pub fn new(file_path: String, shape: Vec<usize>, dtype: crate::tensor::DType) -> Result<Self> {
769 let mut file = File::open(&file_path).map_err(|e| {
771 TrustformersError::tensor_op_error(
772 &format!("Failed to open file for memory mapping: {}", e),
773 "mmap_new",
774 )
775 })?;
776
777 let file_size = file.seek(SeekFrom::End(0)).map_err(|e| {
779 TrustformersError::tensor_op_error(
780 &format!("Failed to get file size: {}", e),
781 "mmap_new",
782 )
783 })?;
784
785 let element_size = dtype.size_in_bytes();
787 let total_elements: usize = shape.iter().product();
788 let expected_size = total_elements * element_size;
789
790 if file_size != expected_size as u64 {
791 return Err(TrustformersError::tensor_op_error(
792 &format!(
793 "File size {} doesn't match expected tensor size {}",
794 file_size, expected_size
795 ),
796 "mmap_new",
797 ));
798 }
799
800 Ok(Self {
801 file_path,
802 shape,
803 dtype,
804 _file: Some(file),
805 file_size,
806 })
807 }
808
809 pub fn load(&self) -> Result<Tensor> {
811 let mut file = File::open(&self.file_path).map_err(|e| {
813 TrustformersError::tensor_op_error(
814 &format!("Failed to open file for reading: {}", e),
815 "mmap_load",
816 )
817 })?;
818
819 let mut buffer = vec![0u8; self.file_size as usize];
820 file.read_exact(&mut buffer).map_err(|e| {
821 TrustformersError::tensor_op_error(
822 &format!("Failed to read file data: {}", e),
823 "mmap_load",
824 )
825 })?;
826
827 match self.dtype {
829 crate::tensor::DType::F32 => {
830 let float_data = buffer
831 .chunks_exact(4)
832 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
833 .collect::<Vec<f32>>();
834 Tensor::from_slice(&float_data, &self.shape)
835 },
836 crate::tensor::DType::F64 => {
837 let float_data = buffer
838 .chunks_exact(8)
839 .map(|chunk| {
840 f64::from_le_bytes([
841 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
842 chunk[7],
843 ])
844 })
845 .collect::<Vec<f64>>();
846 Tensor::from_slice_f64(&float_data, &self.shape)
847 },
848 crate::tensor::DType::I64 => {
849 let int_data = buffer
850 .chunks_exact(8)
851 .map(|chunk| {
852 i64::from_le_bytes([
853 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
854 chunk[7],
855 ])
856 })
857 .collect::<Vec<i64>>();
858 Tensor::from_slice_i64(&int_data, &self.shape)
859 },
860 crate::tensor::DType::I32 => {
861 let int_data = buffer
862 .chunks_exact(4)
863 .map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
864 .collect::<Vec<i32>>();
865 Tensor::from_slice_i32(&int_data, &self.shape)
866 },
867 _ => Err(TrustformersError::tensor_op_error(
868 "Unsupported dtype for memory mapped tensor",
869 "mmap_load",
870 )),
871 }
872 }
873
874 pub fn shape(&self) -> &[usize] {
876 &self.shape
877 }
878
879 pub fn file_path(&self) -> &str {
881 &self.file_path
882 }
883}
884
/// Process-wide tensor pool, initialized at most once via `init_memory_manager`.
static MEMORY_MANAGER: std::sync::OnceLock<TensorMemoryPool> = std::sync::OnceLock::new();
887
888pub fn init_memory_manager(config: MemoryConfig) -> Result<()> {
890 let pool = TensorMemoryPool::new(config);
891 MEMORY_MANAGER.set(pool).map_err(|_| {
892 TrustformersError::invalid_input("Memory manager already initialized".to_string())
893 })?;
894 Ok(())
895}
896
/// Returns the global pool, or `None` if `init_memory_manager` was never called.
pub fn get_memory_manager() -> Option<&'static TensorMemoryPool> {
    MEMORY_MANAGER.get()
}
901
902pub fn get_tensor(shape: &[usize], dtype: crate::tensor::DType) -> Result<Tensor> {
904 if let Some(manager) = get_memory_manager() {
905 manager.get_tensor(shape, dtype)
906 } else {
907 match dtype {
909 crate::tensor::DType::F32 => Tensor::zeros(shape),
910 crate::tensor::DType::F64 => Tensor::zeros_f64(shape),
911 crate::tensor::DType::I64 => Tensor::zeros_i64(shape),
912 _ => Err(TrustformersError::tensor_op_error(
913 "Unsupported dtype",
914 "get_tensor",
915 )),
916 }
917 }
918}
919
920pub fn return_tensor(tensor: Tensor) -> Result<()> {
922 if let Some(manager) = get_memory_manager() {
923 manager.return_tensor(tensor)
924 } else {
925 Ok(()) }
927}
928
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config enables all features with the documented limits.
    #[test]
    fn test_memory_config_default() {
        let config = MemoryConfig::default();
        assert!(config.enable_memory_pool);
        assert!(config.enable_zero_copy);
        assert!(config.enable_mmap);
        assert_eq!(config.max_pool_size, 1024 * 1024 * 1024);
    }

    /// A freshly built pool is empty.
    #[test]
    fn test_tensor_pool_creation() {
        let config = MemoryConfig::default();
        let pool = TensorMemoryPool::new(config);
        let stats = pool.get_stats();
        assert_eq!(stats.total_tensors, 0);
        assert_eq!(stats.current_size_bytes, 0);
    }

    /// Round-trip: get, return, and get again with the same shape.
    #[test]
    fn test_tensor_pool_get_and_return() -> Result<()> {
        let config = MemoryConfig::default();
        let pool = TensorMemoryPool::new(config);

        let shape = vec![2, 3];
        let tensor = pool.get_tensor(&shape, crate::tensor::DType::F32)?;
        assert_eq!(tensor.shape(), shape.as_slice());

        pool.return_tensor(tensor)?;

        let tensor2 = pool.get_tensor(&shape, crate::tensor::DType::F32)?;
        assert_eq!(tensor2.shape(), shape.as_slice());

        Ok(())
    }

    /// A slice view reports the sliced shape and materializes correctly.
    #[test]
    fn test_zero_copy_tensor_view() -> Result<()> {
        let tensor = Arc::new(Tensor::ones(&[10])?);
        let view = TensorView::slice(tensor, 2, 8)?;
        assert_eq!(view.shape(), &[6]);

        let viewed_tensor = view.as_tensor()?;
        assert_eq!(viewed_tensor.shape(), &[6]);

        Ok(())
    }

    /// Writes a matching binary file and loads it back through
    /// `MemoryMappedTensor`.
    #[test]
    fn test_memory_mapped_tensor() -> Result<()> {
        use std::io::Write;

        // BUGFIX: use a unique path under the OS temp dir instead of a fixed
        // file in the process CWD — the old path collided between parallel
        // test runs and polluted the working directory.
        let temp_path = std::env::temp_dir()
            .join(format!("trustformers_mmap_test_{}.bin", std::process::id()));
        let temp_file = temp_path.to_string_lossy().into_owned();
        let data_size = 100 * 100 * std::mem::size_of::<f32>();
        let data: Vec<u8> = vec![0; data_size];

        {
            let mut file = File::create(&temp_file).map_err(|e| {
                TrustformersError::tensor_op_error(
                    &format!("Failed to create test file: {}", e),
                    "test_setup",
                )
            })?;
            file.write_all(&data).map_err(|e| {
                TrustformersError::tensor_op_error(
                    &format!("Failed to write test data: {}", e),
                    "test_setup",
                )
            })?;
        }

        let mmap_tensor = MemoryMappedTensor::new(
            temp_file.clone(),
            vec![100, 100],
            crate::tensor::DType::F32,
        )?;

        assert_eq!(mmap_tensor.shape(), &[100, 100]);
        assert_eq!(mmap_tensor.file_path(), temp_file);

        let loaded = mmap_tensor.load()?;
        assert_eq!(loaded.shape(), &[100, 100]);

        // Best-effort cleanup; failure to remove must not fail the test.
        std::fs::remove_file(&temp_file).ok();

        Ok(())
    }

    /// Global manager: init once, then get/return through the free functions.
    #[test]
    fn test_global_memory_manager() -> Result<()> {
        let config = MemoryConfig::default();
        init_memory_manager(config)?;

        let tensor = get_tensor(&[5, 5], crate::tensor::DType::F32)?;
        assert_eq!(tensor.shape(), [5, 5].as_slice());

        return_tensor(tensor)?;

        Ok(())
    }
}