//! Memory-efficient building blocks for neural networks: a reusable tensor
//! pool, gradient checkpointing, in-place activation operations, memory-aware
//! batch processing, and a chunked, memory-efficient layer.

use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use ndarray::{Array, ArrayD, ArrayView, IxDyn};
use num_traits::Float;
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
use std::sync::{Arc, Mutex, RwLock};

#[cfg(feature = "memory_management")]
use scirs2_core::memory_efficient::BufferPool;

#[cfg(feature = "cache")]
use scirs2_core::cache::{CacheBuilder, TTLSizedCache};

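/// Running statistics about memory allocated through the utilities in this
/// module: bytes currently in use, the peak observed, and how many
/// allocations are live versus total.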
#[derive(Debug, Clone)]
pub struct MemoryUsage {
    /// Bytes currently allocated
    pub current_bytes: usize,
    /// Peak number of bytes allocated at any point
    pub peak_bytes: usize,
    /// Allocations that have not yet been released
    pub active_allocations: usize,
    /// Total number of allocations ever recorded
    pub total_allocations: usize,
}

impl MemoryUsage {
    /// Create an empty usage tracker.
    pub fn new() -> Self {
        Self {
            current_bytes: 0,
            peak_bytes: 0,
            active_allocations: 0,
            total_allocations: 0,
        }
    }

    /// Record an allocation of `bytes`, updating the peak if necessary.
    pub fn allocate(&mut self, bytes: usize) {
        self.current_bytes += bytes;
        self.peak_bytes = self.peak_bytes.max(self.current_bytes);
        self.active_allocations += 1;
        self.total_allocations += 1;
    }

    /// Record a deallocation of `bytes`; counters saturate at zero.
    pub fn deallocate(&mut self, bytes: usize) {
        self.current_bytes = self.current_bytes.saturating_sub(bytes);
        self.active_allocations = self.active_allocations.saturating_sub(1);
    }

    /// Current usage in megabytes.
    pub fn current_mb(&self) -> f64 {
        self.current_bytes as f64 / (1024.0 * 1024.0)
    }

    /// Peak usage in megabytes.
    pub fn peak_mb(&self) -> f64 {
        self.peak_bytes as f64 / (1024.0 * 1024.0)
    }
}

impl Default for MemoryUsage {
    fn default() -> Self {
        Self::new()
    }
}

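/// A simple tensor pool that reuses previously allocated `ArrayD` buffers,
/// keyed by shape, to avoid repeated allocations. Returned tensors are zeroed
/// before reuse, and the pool stops caching once `max_pool_size` bytes are
/// held.
///
/// A minimal usage sketch (kept as `ignore`; it mirrors the unit test at the
/// bottom of this file):
///
/// ```ignore
/// let mut pool = MemoryPool::<f32>::new(10); // 10 MB pool
/// let tensor = pool.allocate(&[64, 64]);
/// pool.deallocate(tensor);
/// assert_eq!(pool.get_pool_stats().unique_shapes, 1);
/// ```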
pub struct MemoryPool<F: Float + Debug> {
    /// Cached tensors ready for reuse, grouped by shape
    available_tensors: HashMap<Vec<usize>, VecDeque<ArrayD<F>>>,
    /// Shared memory usage statistics
    usage: Arc<Mutex<MemoryUsage>>,
    /// Maximum number of bytes the pool may hold
    max_pool_size: usize,
    /// Bytes currently held by the pool
    current_pool_size: usize,
}

impl<F: Float + Debug + Clone + 'static> MemoryPool<F> {
    /// Create a pool that caches at most `max_pool_size_mb` megabytes of tensors.
    pub fn new(max_pool_size_mb: usize) -> Self {
        Self {
            available_tensors: HashMap::new(),
            usage: Arc::new(Mutex::new(MemoryUsage::new())),
            max_pool_size: max_pool_size_mb * 1024 * 1024,
            current_pool_size: 0,
        }
    }

    /// Get a zeroed tensor of the given shape, reusing a cached one if available.
    pub fn allocate(&mut self, shape: &[usize]) -> ArrayD<F> {
        let shape_vec = shape.to_vec();

        // Reuse a cached tensor with a matching shape if we have one.
        if let Some(tensors) = self.available_tensors.get_mut(&shape_vec) {
            if let Some(mut tensor) = tensors.pop_front() {
                tensor.fill(F::zero());

                if let Ok(mut usage) = self.usage.lock() {
                    let bytes = Self::calculate_bytes(&shape_vec);
                    usage.allocate(bytes);
                }

                return tensor;
            }
        }

        // Otherwise allocate a fresh tensor.
        let tensor = Array::zeros(IxDyn(shape));

        if let Ok(mut usage) = self.usage.lock() {
            let bytes = Self::calculate_bytes(&shape_vec);
            usage.allocate(bytes);
        }

        tensor
    }

    /// Return a tensor to the pool; it is cached only if the pool has capacity.
    pub fn deallocate(&mut self, tensor: ArrayD<F>) {
        let shape = tensor.shape().to_vec();
        let bytes = Self::calculate_bytes(&shape);

        if self.current_pool_size + bytes <= self.max_pool_size {
            self.available_tensors
                .entry(shape)
                .or_default()
                .push_back(tensor);
            self.current_pool_size += bytes;
        }

        if let Ok(mut usage) = self.usage.lock() {
            usage.deallocate(bytes);
        }
    }

    /// Snapshot of the pool's memory usage statistics.
    pub fn get_usage(&self) -> MemoryUsage {
        self.usage
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone()
    }

    /// Drop all cached tensors.
    pub fn clear(&mut self) {
        self.available_tensors.clear();
        self.current_pool_size = 0;
    }

    fn calculate_bytes(shape: &[usize]) -> usize {
        let elements: usize = shape.iter().product();
        elements * std::mem::size_of::<F>()
    }

    /// Statistics about the tensors currently cached by the pool.
    pub fn get_pool_stats(&self) -> PoolStatistics {
        let total_tensors: usize = self.available_tensors.values().map(|v| v.len()).sum();
        let unique_shapes = self.available_tensors.len();

        PoolStatistics {
            total_cached_tensors: total_tensors,
            unique_shapes,
            current_pool_size_mb: self.current_pool_size as f64 / (1024.0 * 1024.0),
            max_pool_size_mb: self.max_pool_size as f64 / (1024.0 * 1024.0),
        }
    }
}

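/// Summary of what the memory pool currently holds.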
#[derive(Debug, Clone)]
pub struct PoolStatistics {
    pub total_cached_tensors: usize,
    pub unique_shapes: usize,
    pub current_pool_size_mb: f64,
    pub max_pool_size_mb: f64,
}

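/// Gradient checkpointing support: stores activations for selected layers so
/// they can be reused (or recomputed from) during the backward pass instead of
/// keeping every intermediate activation in memory. Storing a checkpoint fails
/// if the configured memory threshold would be exceeded.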
pub struct GradientCheckpointing<F: Float + Debug> {
    /// Names of the layers whose activations should be checkpointed
    checkpoint_layers: Vec<String>,
    /// Stored activations, keyed by layer name
    checkpoints: HashMap<String, ArrayD<F>>,
    /// Maximum memory (in MB) the stored checkpoints may occupy
    memory_threshold_mb: f64,
    /// Shared memory usage statistics
    memory_usage: Arc<RwLock<MemoryUsage>>,
}

impl<F: Float + Debug + Clone + 'static + ndarray::ScalarOperand> GradientCheckpointing<F> {
    /// Create a checkpointing helper with the given memory threshold in megabytes.
    pub fn new(memory_threshold_mb: f64) -> Self {
        Self {
            checkpoint_layers: Vec::new(),
            checkpoints: HashMap::new(),
            memory_threshold_mb,
            memory_usage: Arc::new(RwLock::new(MemoryUsage::new())),
        }
    }

    /// Register a layer whose activations should be checkpointed.
    pub fn add_checkpoint_layer(&mut self, layer_name: String) {
        self.checkpoint_layers.push(layer_name);
    }

    /// Store an activation for `layer_name` if it is a registered checkpoint layer.
    pub fn store_checkpoint(&mut self, layer_name: &str, activation: ArrayD<F>) -> Result<()> {
        if self.checkpoint_layers.contains(&layer_name.to_string()) {
            let bytes = activation.len() * std::mem::size_of::<F>();

            if let Ok(mut usage) = self.memory_usage.write() {
                usage.allocate(bytes);

                if usage.current_mb() > self.memory_threshold_mb {
                    return Err(NeuralError::ComputationError(format!(
                        "Memory threshold exceeded: {:.2}MB > {:.2}MB",
                        usage.current_mb(),
                        self.memory_threshold_mb
                    )));
                }
            }

            self.checkpoints.insert(layer_name.to_string(), activation);
        }
        Ok(())
    }

    /// Retrieve a stored checkpoint, if any.
    pub fn get_checkpoint(&self, layer_name: &str) -> Option<&ArrayD<F>> {
        self.checkpoints.get(layer_name)
    }

    /// Remove all stored checkpoints and release their accounted memory.
    pub fn clear_checkpoints(&mut self) {
        let total_bytes: usize = self
            .checkpoints
            .values()
            .map(|arr| arr.len() * std::mem::size_of::<F>())
            .sum();

        self.checkpoints.clear();

        if let Ok(mut usage) = self.memory_usage.write() {
            usage.deallocate(total_bytes);
        }
    }

    /// Snapshot of the checkpoint memory usage.
    pub fn get_memory_usage(&self) -> MemoryUsage {
        self.memory_usage
            .read()
            .map(|usage| usage.clone())
            .unwrap_or_default()
    }

    /// Recompute activations starting from the checkpoint stored for
    /// `start_layer` by replaying the given layers' forward passes.
    pub fn recompute_from_checkpoint<L>(
        &self,
        layers: &[L],
        start_layer: &str,
        _target_layer: &str,
        _input: &ArrayD<F>,
    ) -> Result<ArrayD<F>>
    where
        L: Layer<F>,
    {
        let checkpoint_activation = self.get_checkpoint(start_layer).ok_or_else(|| {
            NeuralError::ComputationError(format!("No checkpoint found for layer: {}", start_layer))
        })?;

        let mut current_activation = checkpoint_activation.clone();

        // Replay the forward pass through the provided layers.
        for layer in layers {
            current_activation = layer.forward(&current_activation)?;
        }

        Ok(current_activation)
    }
}

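/// In-place tensor operations that mutate their input instead of allocating a
/// new array, trading a little flexibility for lower peak memory usage.
///
/// A minimal usage sketch (kept as `ignore`; it mirrors the unit test at the
/// bottom of this file):
///
/// ```ignore
/// let mut a = ndarray::Array2::from_elem((2, 2), -1.0f32).into_dyn();
/// InPlaceOperations::relu_inplace(&mut a);       // negatives clamped to zero
/// InPlaceOperations::scale_inplace(&mut a, 2.0); // in-place scaling
/// ```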
pub struct InPlaceOperations;

impl InPlaceOperations {
    /// Apply ReLU in place: `x = max(x, 0)`.
    pub fn relu_inplace<F: Float + Debug>(array: &mut ArrayD<F>) {
        array.mapv_inplace(|x| x.max(F::zero()));
    }

    /// Apply the logistic sigmoid in place.
    pub fn sigmoid_inplace<F: Float + Debug>(array: &mut ArrayD<F>) {
        array.mapv_inplace(|x| F::one() / (F::one() + (-x).exp()));
    }

    /// Apply tanh in place.
    pub fn tanh_inplace<F: Float + Debug>(array: &mut ArrayD<F>) {
        array.mapv_inplace(|x| x.tanh());
    }

    /// Element-wise `target += source`; both arrays must have the same shape.
    pub fn add_inplace<F: Float + Debug>(target: &mut ArrayD<F>, source: &ArrayD<F>) -> Result<()> {
        if target.shape() != source.shape() {
            return Err(NeuralError::ShapeMismatch(
                "Arrays must have the same shape for in-place addition".to_string(),
            ));
        }

        for (t, &s) in target.iter_mut().zip(source.iter()) {
            *t = *t + s;
        }

        Ok(())
    }

    /// Multiply every element by `factor` in place.
    pub fn scale_inplace<F: Float + Debug>(array: &mut ArrayD<F>, factor: F) {
        array.mapv_inplace(|x| x * factor);
    }

    /// Normalize to zero mean and unit variance in place. Arrays with zero
    /// variance are left unchanged.
    pub fn normalize_inplace<F: Float + Debug + Clone + num_traits::FromPrimitive>(
        array: &mut ArrayD<F>,
    ) -> Result<()> {
        let mean = array.mean().unwrap_or(F::zero());
        let variance = array
            .mapv(|x| (x - mean) * (x - mean))
            .mean()
            .unwrap_or(F::zero());
        let std_dev = variance.sqrt();

        if std_dev == F::zero() {
            return Ok(());
        }

        array.mapv_inplace(|x| (x - mean) / std_dev);
        Ok(())
    }

    /// Apply inverted dropout in place: during training, each element is zeroed
    /// with probability `dropout_rate` and survivors are scaled by
    /// `1 / (1 - dropout_rate)`. Does nothing in inference mode.
    pub fn dropout_inplace<F: Float + Debug>(
        array: &mut ArrayD<F>,
        dropout_rate: f64,
        training: bool,
    ) -> Result<()> {
        if !training {
            return Ok(());
        }

        let keep_prob = 1.0 - dropout_rate;
        let scale_factor = F::from(1.0 / keep_prob).unwrap();

        for element in array.iter_mut() {
            if rand::random::<f64>() < dropout_rate {
                *element = F::zero();
            } else {
                *element = *element * scale_factor;
            }
        }

        Ok(())
    }
}

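/// Processes large inputs in batches whose size is adapted to the memory that
/// is still available, falling back to smaller batches (and clearing the
/// tensor pool) as the configured threshold is approached.
///
/// A minimal usage sketch (kept as `ignore`; it mirrors the unit test at the
/// bottom of this file):
///
/// ```ignore
/// let mut processor = MemoryAwareBatchProcessor::<f32>::new(100, 50.0, 10);
/// let input = ndarray::Array2::from_elem((20, 5), 1.0f32).into_dyn();
/// let outputs = processor
///     .process_batches(&input, |batch| Ok(batch.to_owned()))
///     .unwrap();
/// assert!(!outputs.is_empty());
/// ```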
pub struct MemoryAwareBatchProcessor<F: Float + Debug> {
    /// Largest batch size the processor will use
    max_batch_size: usize,
    /// Pool used to recycle intermediate tensors
    memory_pool: MemoryPool<F>,
    /// Soft memory limit in megabytes
    memory_threshold_mb: f64,
}

impl<F: Float + Debug + Clone + 'static> MemoryAwareBatchProcessor<F> {
    /// Create a processor bounded by `max_memory_mb`, with a soft threshold and
    /// a tensor pool of `pool_size_mb` megabytes.
    pub fn new(max_memory_mb: usize, memory_threshold_mb: f64, pool_size_mb: usize) -> Self {
        Self {
            max_batch_size: Self::calculate_max_batch_size(max_memory_mb),
            memory_pool: MemoryPool::new(pool_size_mb),
            memory_threshold_mb,
        }
    }

    /// Split `input` along its first axis into memory-sized batches and apply
    /// `process_fn` to each, returning the per-batch results.
    pub fn process_batches<ProcessFn>(
        &mut self,
        input: &ArrayD<F>,
        mut process_fn: ProcessFn,
    ) -> Result<Vec<ArrayD<F>>>
    where
        ProcessFn: FnMut(&ArrayView<F, IxDyn>) -> Result<ArrayD<F>>,
    {
        let total_samples = input.shape()[0];
        let mut results = Vec::new();
        let mut start_idx = 0;

        while start_idx < total_samples {
            let current_usage = self.memory_pool.get_usage();
            let available_memory_mb = self.memory_threshold_mb - current_usage.current_mb();

            // Shrink the batch size as available memory runs low.
            let batch_size = if available_memory_mb < 100.0 {
                (self.max_batch_size / 4).max(1)
            } else if available_memory_mb < 200.0 {
                self.max_batch_size / 2
            } else {
                self.max_batch_size
            };

            let end_idx = (start_idx + batch_size).min(total_samples);
            let batch = input.slice(ndarray::s![start_idx..end_idx, ..]).into_dyn();

            let result = process_fn(&batch)?;
            results.push(result);

            start_idx = end_idx;

            // Release pooled tensors once usage gets close to the threshold.
            if current_usage.current_mb() > self.memory_threshold_mb * 0.8 {
                self.memory_pool.clear();
            }
        }

        Ok(results)
    }

    fn calculate_max_batch_size(max_memory_mb: usize) -> usize {
        let max_memory_bytes = max_memory_mb * 1024 * 1024;
        // Rough per-sample footprint used to derive the batch size bound.
        let bytes_per_sample = 1024;
        (max_memory_bytes / bytes_per_sample).max(1)
    }

    /// Current statistics for the processor and its tensor pool.
    pub fn get_stats(&self) -> BatchProcessorStats {
        let usage = self.memory_pool.get_usage();
        let pool_stats = self.memory_pool.get_pool_stats();

        BatchProcessorStats {
            max_batch_size: self.max_batch_size,
            current_memory_mb: usage.current_mb(),
            peak_memory_mb: usage.peak_mb(),
            memory_threshold_mb: self.memory_threshold_mb,
            pool_stats,
        }
    }
}

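/// Snapshot of a `MemoryAwareBatchProcessor`'s configuration and memory usage.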
#[derive(Debug, Clone)]
pub struct BatchProcessorStats {
    pub max_batch_size: usize,
    pub current_memory_mb: f64,
    pub peak_memory_mb: f64,
    pub memory_threshold_mb: f64,
    pub pool_stats: PoolStatistics,
}

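/// A memory-efficient dense (fully connected) layer that processes its input
/// in fixed-size chunks along the batch dimension to bound peak memory usage.
/// Optional features add a shared buffer pool and a TTL-bounded activation
/// cache.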
pub struct MemoryEfficientLayer {
    /// Layer weights (not used by the current simplified forward pass)
    #[cfg(feature = "memory_efficient")]
    #[allow(dead_code)]
    weights: ArrayD<f32>,

    /// Bias vector applied to the layer outputs
    bias: ndarray::Array1<f32>,

    /// Number of samples processed per chunk
    chunk_size: usize,

    /// Shared buffer pool for intermediate allocations
    #[cfg(feature = "memory_management")]
    #[allow(dead_code)]
    buffer_pool: Arc<BufferPool>,

    /// TTL-bounded cache of recently computed activations
    #[cfg(feature = "cache")]
    activation_cache: TTLSizedCache<String, ArrayD<f32>>,
}

impl MemoryEfficientLayer {
    /// Create a layer with the given input/output sizes; `chunk_size` defaults
    /// to 1024 samples per chunk when not specified.
    pub fn new(input_size: usize, output_size: usize, chunk_size: Option<usize>) -> Result<Self> {
        let _weights_shape = [input_size, output_size];
        let default_chunk_size = chunk_size.unwrap_or(1024);

        #[cfg(feature = "memory_efficient")]
        let weights = ArrayD::zeros(IxDyn(&_weights_shape));

        let bias = ndarray::Array1::zeros(output_size);

        #[cfg(feature = "memory_management")]
        let buffer_pool = Arc::new(
            BufferPool::new(1000, default_chunk_size * output_size, false, 64).unwrap(),
        );

        #[cfg(feature = "cache")]
        let activation_cache = CacheBuilder::new()
            .with_size(100)
            .with_ttl(300)
            .build_sized_cache();

        Ok(Self {
            #[cfg(feature = "memory_efficient")]
            weights,
            bias,
            chunk_size: default_chunk_size,
            #[cfg(feature = "memory_management")]
            buffer_pool,
            #[cfg(feature = "cache")]
            activation_cache,
        })
    }

    /// Forward pass that walks the batch dimension in `chunk_size` slices so
    /// only one chunk of activations is materialized at a time.
    pub fn forward(&self, input: &ArrayD<f32>) -> Result<ArrayD<f32>> {
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let _input_size = input_shape[1];
        let output_size = self.bias.len();

        let mut output = Array::zeros((batch_size, output_size));

        let chunks = batch_size.div_ceil(self.chunk_size);

        for chunk_idx in 0..chunks {
            let start_idx = chunk_idx * self.chunk_size;
            let end_idx = std::cmp::min(start_idx + self.chunk_size, batch_size);
            let _chunk_batch_size = end_idx - start_idx;

            let input_chunk = input.slice(ndarray::s![start_idx..end_idx, ..]);

            #[cfg(feature = "memory_efficient")]
            let chunk_output = self.forward_chunk(&input_chunk.into_dyn())?;

            #[cfg(not(feature = "memory_efficient"))]
            let chunk_output = self.forward_chunk_fallback(&input_chunk.into_dyn())?;

            output
                .slice_mut(ndarray::s![start_idx..end_idx, ..])
                .assign(&chunk_output);
        }

        Ok(output.into_dyn())
    }

    /// Simplified chunk forward pass used when the `memory_efficient` feature
    /// is enabled: the input chunk is copied and combined with the bias.
    #[cfg(feature = "memory_efficient")]
    fn forward_chunk(&self, input_chunk: &ArrayView<f32, IxDyn>) -> Result<ndarray::Array2<f32>> {
        let chunk_shape = input_chunk.shape();
        let chunk_batch_size = chunk_shape[0];
        let output_size = self.bias.len();

        let result = input_chunk.to_owned();

        let mut output = ndarray::Array2::zeros((chunk_batch_size, output_size));
        for (mut row, bias_val) in output.rows_mut().into_iter().zip(self.bias.iter().cycle()) {
            for (out_val, result_val) in row.iter_mut().zip(result.iter()) {
                *out_val = result_val + bias_val;
            }
        }

        Ok(output)
    }

    /// Fallback chunk forward pass: a naive matrix multiply against
    /// (placeholder) zero weights followed by a bias add.
    #[cfg(not(feature = "memory_efficient"))]
    fn forward_chunk_fallback(
        &self,
        input_chunk: &ArrayView<f32, IxDyn>,
    ) -> Result<ndarray::Array2<f32>> {
        let input_2d = input_chunk
            .view()
            .into_dimensionality::<ndarray::Ix2>()
            .map_err(|e| {
                NeuralError::DimensionMismatch(format!("Failed to convert to 2D: {}", e))
            })?;

        let (_chunk_batch_size, input_size) = input_2d.dim();
        let output_size = self.bias.len();
        let weights_2d = ndarray::Array2::<f32>::zeros((input_size, output_size));

        // Naive matrix multiplication: result = input_2d * weights_2d.
        let mut result =
            ndarray::Array2::<f32>::zeros((input_2d.shape()[0], weights_2d.shape()[1]));
        for i in 0..input_2d.shape()[0] {
            for j in 0..weights_2d.shape()[1] {
                let mut sum = 0.0f32;
                for k in 0..input_2d.shape()[1] {
                    sum += input_2d[[i, k]] * weights_2d[[k, j]];
                }
                result[[i, j]] = sum;
            }
        }

        // Add the bias to every row.
        for mut row in result.rows_mut() {
            for (out_val, bias_val) in row.iter_mut().zip(self.bias.iter()) {
                *out_val += bias_val;
            }
        }

        Ok(result)
    }

    /// Cache an activation under the given key.
    #[cfg(feature = "cache")]
    pub fn cache_activation(&mut self, key: String, activation: ArrayD<f32>) {
        self.activation_cache.insert(key, activation);
    }

    /// Look up a previously cached activation.
    #[cfg(feature = "cache")]
    pub fn get_cached_activation(&mut self, key: &str) -> Option<ArrayD<f32>> {
        self.activation_cache.get(&key.to_string())
    }
}

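/// Helper holding references to the weights and bias for chunk-wise forward
/// processing (currently unused).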
#[cfg(feature = "memory_efficient")]
#[allow(dead_code)]
struct ChunkForwardProcessor<'a> {
    weights: &'a ArrayD<f32>,
    bias: &'a ndarray::Array1<f32>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_memory_pool() {
        // 10 MB pool.
        let mut pool = MemoryPool::<f32>::new(10);

        let tensor1 = pool.allocate(&[100, 100]);
        assert_eq!(tensor1.shape(), [100, 100]);

        pool.deallocate(tensor1);

        // The second allocation should reuse the cached tensor.
        let tensor2 = pool.allocate(&[100, 100]);
        assert_eq!(tensor2.shape(), [100, 100]);

        let stats = pool.get_pool_stats();
        assert_eq!(stats.unique_shapes, 1);
    }

    #[test]
    fn test_gradient_checkpointing() {
        // 100 MB threshold.
        let mut checkpointing = GradientCheckpointing::<f64>::new(100.0);
        checkpointing.add_checkpoint_layer("layer1".to_string());

        let activation = Array2::from_elem((10, 10), 1.0).into_dyn();
        checkpointing
            .store_checkpoint("layer1", activation)
            .unwrap();

        assert!(checkpointing.get_checkpoint("layer1").is_some());

        checkpointing.clear_checkpoints();
        assert!(checkpointing.get_checkpoint("layer1").is_none());
    }

    #[test]
    fn test_in_place_operations() {
        let mut array = Array2::from_elem((3, 3), -1.0).into_dyn();

        InPlaceOperations::relu_inplace(&mut array);
        for &val in array.iter() {
            assert!(val >= 0.0);
        }

        InPlaceOperations::scale_inplace(&mut array, 2.0);
        for &val in array.iter() {
            // ReLU already clamped every element to zero, so scaling keeps them zero.
            assert_eq!(val, 0.0);
        }
    }

    #[test]
    fn test_memory_aware_batch_processor() {
        let mut processor = MemoryAwareBatchProcessor::<f32>::new(100, 50.0, 10);

        let input = Array2::from_elem((20, 5), 1.0).into_dyn();

        let results = processor
            .process_batches(&input, |batch| Ok(batch.to_owned()))
            .unwrap();

        assert!(!results.is_empty());

        let stats = processor.get_stats();
        assert!(stats.max_batch_size > 0);
    }

    #[test]
    fn test_memory_usage_tracking() {
        let mut usage = MemoryUsage::new();

        // Allocate 1 MB.
        usage.allocate(1024 * 1024);
        assert_eq!(usage.current_mb(), 1.0);
        assert_eq!(usage.peak_mb(), 1.0);
        assert_eq!(usage.active_allocations, 1);

        // Allocate another 2 MB.
        usage.allocate(2 * 1024 * 1024);
        assert_eq!(usage.current_mb(), 3.0);
        assert_eq!(usage.peak_mb(), 3.0);
        assert_eq!(usage.active_allocations, 2);

        // Free 1 MB; the peak is unchanged.
        usage.deallocate(1024 * 1024);
        assert_eq!(usage.current_mb(), 2.0);
        assert_eq!(usage.peak_mb(), 3.0);
        assert_eq!(usage.active_allocations, 1);
    }
}