// scirs2_neural/memory_efficient.rs

//! Memory-efficient implementations for neural networks
//!
//! This module provides memory optimization techniques including:
//! - Gradient checkpointing for reduced memory usage during training
//! - In-place operations to minimize memory allocations
//! - Memory pool management for efficient tensor allocation
//! - Memory-aware batch processing
//! - Lazy evaluation and computation graphs
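//!
//! A typical flow draws working buffers from a [`MemoryPool`], applies an
//! in-place activation, and returns the buffers for reuse. The sketch below is
//! illustrative only and assumes this module is exposed as
//! `scirs2_neural::memory_efficient`:
//!
//! ```ignore
//! use scirs2_neural::memory_efficient::{InPlaceOperations, MemoryPool};
//!
//! // Cache at most 64 MB of released tensors for reuse
//! let mut pool = MemoryPool::<f32>::new(64);
//!
//! // Borrow a zeroed buffer, use it, then hand it back to the pool
//! let mut activations = pool.allocate(&[32, 128]);
//! InPlaceOperations::relu_inplace(&mut activations);
//! pool.deallocate(activations);
//! ```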

use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use ndarray::{Array, ArrayD, ArrayView, IxDyn};
use num_traits::Float;
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
use std::sync::{Arc, Mutex, RwLock};

// FIXME: chunk_wise_op usage commented out due to signature mismatch
// #[cfg(feature = "memory_efficient")]
// use scirs2_core::memory_efficient::chunk_wise_op;
#[cfg(feature = "memory_management")]
use scirs2_core::memory_efficient::BufferPool;

// FIXME: MemoryManager not available in current scirs2-core
// #[cfg(feature = "memory_management")]
// use scirs2_core::resource::memory::{AllocationStrategy, MemoryManager};

// FIXME: ChunkProcessor not available as trait in current scirs2-core
// #[cfg(feature = "memory_management")]
// use scirs2_core::ChunkProcessor;

// Note: These imports may need to be adjusted based on available types in scirs2_core
// #[cfg(feature = "memory_management")]
// use scirs2_core::memory_management::{
//     AllocationStrategy, BufferPool, MemoryManager, MemoryMetrics,
// };

#[cfg(feature = "cache")]
use scirs2_core::cache::{CacheBuilder, TTLSizedCache};

/// Memory usage tracking and reporting
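///
/// Counters are updated explicitly by the owning component; this struct does
/// no tracking on its own. A minimal sketch (paths assumed, not verified):
///
/// ```ignore
/// use scirs2_neural::memory_efficient::MemoryUsage;
///
/// let mut usage = MemoryUsage::new();
/// usage.allocate(4 * 1024 * 1024); // record a 4 MB allocation
/// usage.deallocate(4 * 1024 * 1024); // ...and its release
/// assert_eq!(usage.current_mb(), 0.0);
/// assert_eq!(usage.peak_mb(), 4.0); // the peak is retained
/// ```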
#[derive(Debug, Clone)]
pub struct MemoryUsage {
    /// Current memory usage in bytes
    pub current_bytes: usize,
    /// Peak memory usage in bytes
    pub peak_bytes: usize,
    /// Number of active allocations
    pub active_allocations: usize,
    /// Total allocations made
    pub total_allocations: usize,
}

impl MemoryUsage {
    /// Create a new memory usage tracker
    pub fn new() -> Self {
        Self {
            current_bytes: 0,
            peak_bytes: 0,
            active_allocations: 0,
            total_allocations: 0,
        }
    }

    /// Update memory usage statistics
    pub fn allocate(&mut self, bytes: usize) {
        self.current_bytes += bytes;
        self.peak_bytes = self.peak_bytes.max(self.current_bytes);
        self.active_allocations += 1;
        self.total_allocations += 1;
    }

    /// Record memory deallocation
    pub fn deallocate(&mut self, bytes: usize) {
        self.current_bytes = self.current_bytes.saturating_sub(bytes);
        self.active_allocations = self.active_allocations.saturating_sub(1);
    }

    /// Get memory usage in MB
    pub fn current_mb(&self) -> f64 {
        self.current_bytes as f64 / (1024.0 * 1024.0)
    }

    /// Get peak memory usage in MB
    pub fn peak_mb(&self) -> f64 {
        self.peak_bytes as f64 / (1024.0 * 1024.0)
    }
}

impl Default for MemoryUsage {
    fn default() -> Self {
        Self::new()
    }
}

/// Memory pool for efficient tensor allocation and reuse
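///
/// Released tensors are cached by shape (up to `max_pool_size`) and zeroed
/// before being handed out again. A minimal sketch (paths assumed, not
/// verified):
///
/// ```ignore
/// use scirs2_neural::memory_efficient::MemoryPool;
///
/// let mut pool = MemoryPool::<f32>::new(8); // cache at most 8 MB
/// let buffer = pool.allocate(&[64, 64]);
/// pool.deallocate(buffer); // cached for the next request of this shape
/// assert_eq!(pool.get_pool_stats().total_cached_tensors, 1);
/// ```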
pub struct MemoryPool<F: Float + Debug> {
    /// Available tensors organized by shape
    available_tensors: HashMap<Vec<usize>, VecDeque<ArrayD<F>>>,
    /// Memory usage tracking
    usage: Arc<Mutex<MemoryUsage>>,
    /// Maximum pool size in bytes
    max_pool_size: usize,
    /// Current pool size in bytes
    current_pool_size: usize,
}

impl<F: Float + Debug + Clone + 'static> MemoryPool<F> {
    /// Create a new memory pool
    pub fn new(max_pool_size_mb: usize) -> Self {
        Self {
            available_tensors: HashMap::new(),
            usage: Arc::new(Mutex::new(MemoryUsage::new())),
            max_pool_size: max_pool_size_mb * 1024 * 1024,
            current_pool_size: 0,
        }
    }

    /// Allocate or reuse a tensor with the given shape
    pub fn allocate(&mut self, shape: &[usize]) -> ArrayD<F> {
        let shape_vec = shape.to_vec();
        let bytes = Self::calculate_bytes(&shape_vec);

        // Try to reuse an existing tensor
        if let Some(tensors) = self.available_tensors.get_mut(&shape_vec) {
            if let Some(mut tensor) = tensors.pop_front() {
                // Zero out the tensor for reuse
                tensor.fill(F::zero());

                // The tensor leaves the cache, so it no longer counts toward the pool size
                self.current_pool_size = self.current_pool_size.saturating_sub(bytes);

                // Update memory usage
                if let Ok(mut usage) = self.usage.lock() {
                    usage.allocate(bytes);
                }

                return tensor;
            }
        }

        // Create a new tensor if none is available
        let tensor = Array::zeros(IxDyn(shape));

        // Update memory usage
        if let Ok(mut usage) = self.usage.lock() {
            usage.allocate(bytes);
        }

        tensor
    }

    /// Return a tensor to the pool for reuse
    pub fn deallocate(&mut self, tensor: ArrayD<F>) {
        let shape = tensor.shape().to_vec();
        let bytes = Self::calculate_bytes(&shape);

        // Check if we have space in the pool
        if self.current_pool_size + bytes <= self.max_pool_size {
            self.available_tensors
                .entry(shape)
                .or_default()
                .push_back(tensor);
            self.current_pool_size += bytes;
        }

        // Update memory usage
        if let Ok(mut usage) = self.usage.lock() {
            usage.deallocate(bytes);
        }
    }

    /// Get current memory usage
    pub fn get_usage(&self) -> MemoryUsage {
        self.usage
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone()
    }

    /// Clear the memory pool
    pub fn clear(&mut self) {
        self.available_tensors.clear();
        self.current_pool_size = 0;
    }

    /// Calculate memory usage for a tensor shape (assuming F is f32/f64)
    fn calculate_bytes(shape: &[usize]) -> usize {
        let elements: usize = shape.iter().product();
        elements * std::mem::size_of::<F>()
    }

    /// Get pool statistics
    pub fn get_pool_stats(&self) -> PoolStatistics {
        let total_tensors: usize = self.available_tensors.values().map(|v| v.len()).sum();
        let unique_shapes = self.available_tensors.len();

        PoolStatistics {
            total_cached_tensors: total_tensors,
            unique_shapes,
            current_pool_size_mb: self.current_pool_size as f64 / (1024.0 * 1024.0),
            max_pool_size_mb: self.max_pool_size as f64 / (1024.0 * 1024.0),
        }
    }
}

/// Statistics about the memory pool
#[derive(Debug, Clone)]
pub struct PoolStatistics {
    /// Number of tensors currently cached
    pub total_cached_tensors: usize,
    /// Number of unique tensor shapes in the pool
    pub unique_shapes: usize,
    /// Current pool size in MB
    pub current_pool_size_mb: f64,
    /// Maximum pool size in MB
    pub max_pool_size_mb: f64,
}

/// Gradient checkpointing implementation for memory-efficient training
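///
/// Only activations registered as checkpoints are retained; intermediate
/// activations can be recomputed from the nearest checkpoint during the
/// backward pass. A minimal sketch (paths assumed, not verified):
///
/// ```ignore
/// use ndarray::Array2;
/// use scirs2_neural::memory_efficient::GradientCheckpointing;
///
/// let mut ckpt = GradientCheckpointing::<f64>::new(512.0); // 512 MB budget
/// ckpt.add_checkpoint_layer("layer1".to_string());
///
/// let activation = Array2::from_elem((8, 16), 0.5).into_dyn();
/// ckpt.store_checkpoint("layer1", activation).unwrap();
/// assert!(ckpt.get_checkpoint("layer1").is_some());
///
/// ckpt.clear_checkpoints(); // release the stored activations
/// ```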
pub struct GradientCheckpointing<F: Float + Debug> {
    /// Checkpoint layers - only these will store activations
    checkpoint_layers: Vec<String>,
    /// Stored activations at checkpoint layers
    checkpoints: HashMap<String, ArrayD<F>>,
    /// Memory usage threshold for automatic checkpointing
    memory_threshold_mb: f64,
    /// Current memory usage tracker
    memory_usage: Arc<RwLock<MemoryUsage>>,
}

impl<F: Float + Debug + Clone + 'static + ndarray::ScalarOperand> GradientCheckpointing<F> {
    /// Create a new gradient checkpointing manager
    pub fn new(memory_threshold_mb: f64) -> Self {
        Self {
            checkpoint_layers: Vec::new(),
            checkpoints: HashMap::new(),
            memory_threshold_mb,
            memory_usage: Arc::new(RwLock::new(MemoryUsage::new())),
        }
    }

    /// Add a layer as a checkpoint point
    pub fn add_checkpoint_layer(&mut self, layer_name: String) {
        self.checkpoint_layers.push(layer_name);
    }

    /// Store activation at a checkpoint
    pub fn store_checkpoint(&mut self, layer_name: &str, activation: ArrayD<F>) -> Result<()> {
        if self.checkpoint_layers.contains(&layer_name.to_string()) {
            // Calculate memory usage
            let bytes = activation.len() * std::mem::size_of::<F>();

            if let Ok(mut usage) = self.memory_usage.write() {
                usage.allocate(bytes);

                // Check if we're exceeding memory threshold
                if usage.current_mb() > self.memory_threshold_mb {
                    return Err(NeuralError::ComputationError(format!(
                        "Memory threshold exceeded: {:.2}MB > {:.2}MB",
                        usage.current_mb(),
                        self.memory_threshold_mb
                    )));
                }
            }

            self.checkpoints.insert(layer_name.to_string(), activation);
        }
        Ok(())
    }

    /// Retrieve activation from checkpoint
    pub fn get_checkpoint(&self, layer_name: &str) -> Option<&ArrayD<F>> {
        self.checkpoints.get(layer_name)
    }

    /// Clear checkpoints to free memory
    pub fn clear_checkpoints(&mut self) {
        let total_bytes: usize = self
            .checkpoints
            .values()
            .map(|arr| arr.len() * std::mem::size_of::<F>())
            .sum();

        self.checkpoints.clear();

        if let Ok(mut usage) = self.memory_usage.write() {
            usage.deallocate(total_bytes);
        }
    }

    /// Get current memory usage
    pub fn get_memory_usage(&self) -> MemoryUsage {
        self.memory_usage
            .read()
            .map(|usage| usage.clone())
            .unwrap_or_default()
    }

    /// Recompute forward pass from last checkpoint
    pub fn recompute_from_checkpoint<L>(
        &self,
        layers: &[L],
        start_layer: &str,
        _target_layer: &str,
        _input: &ArrayD<F>,
    ) -> Result<ArrayD<F>>
    where
        L: Layer<F>,
    {
        // Find the checkpoint closest to target_layer
        let checkpoint_activation = self.get_checkpoint(start_layer).ok_or_else(|| {
            NeuralError::ComputationError(format!("No checkpoint found for layer: {}", start_layer))
        })?;

        // Recompute forward pass from checkpoint to target
        let mut current_activation = checkpoint_activation.clone();

        // This is a simplified implementation
        // In practice, you'd need layer ordering and proper forward pass logic
        for layer in layers {
            current_activation = layer.forward(&current_activation)?;
        }

        Ok(current_activation)
    }
}

/// In-place operations manager for minimizing memory allocations
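///
/// Each helper mutates the given array directly instead of allocating a new
/// one. A minimal sketch (paths assumed, not verified):
///
/// ```ignore
/// use ndarray::Array2;
/// use scirs2_neural::memory_efficient::InPlaceOperations;
///
/// let mut x = Array2::from_elem((2, 3), -0.5_f32).into_dyn();
/// InPlaceOperations::relu_inplace(&mut x); // negatives clamped to zero
/// InPlaceOperations::scale_inplace(&mut x, 2.0); // still all zeros here
/// assert!(x.iter().all(|&v| v == 0.0));
/// ```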
pub struct InPlaceOperations;

impl InPlaceOperations {
    /// In-place ReLU activation
    pub fn relu_inplace<F: Float + Debug>(array: &mut ArrayD<F>) {
        array.mapv_inplace(|x| x.max(F::zero()));
    }

    /// In-place sigmoid activation
    pub fn sigmoid_inplace<F: Float + Debug>(array: &mut ArrayD<F>) {
        array.mapv_inplace(|x| F::one() / (F::one() + (-x).exp()));
    }

    /// In-place tanh activation
    pub fn tanh_inplace<F: Float + Debug>(array: &mut ArrayD<F>) {
        array.mapv_inplace(|x| x.tanh());
    }

    /// In-place addition
    pub fn add_inplace<F: Float + Debug>(target: &mut ArrayD<F>, source: &ArrayD<F>) -> Result<()> {
        if target.shape() != source.shape() {
            return Err(NeuralError::ShapeMismatch(
                "Arrays must have the same shape for in-place addition".to_string(),
            ));
        }

        for (t, &s) in target.iter_mut().zip(source.iter()) {
            *t = *t + s;
        }

        Ok(())
    }

    /// In-place scalar multiplication
    pub fn scale_inplace<F: Float + Debug>(array: &mut ArrayD<F>, factor: F) {
        array.mapv_inplace(|x| x * factor);
    }

    /// In-place normalization (subtract mean, divide by std)
    pub fn normalize_inplace<F: Float + Debug + Clone + num_traits::FromPrimitive>(
        array: &mut ArrayD<F>,
    ) -> Result<()> {
        let mean = array.mean().unwrap_or(F::zero());
        let variance = array
            .mapv(|x| (x - mean) * (x - mean))
            .mean()
            .unwrap_or(F::zero());
        let std_dev = variance.sqrt();

        if std_dev == F::zero() {
            return Ok(()); // Avoid division by zero
        }

        array.mapv_inplace(|x| (x - mean) / std_dev);
        Ok(())
    }

    /// In-place dropout (sets elements to zero based on probability)
    pub fn dropout_inplace<F: Float + Debug>(
        array: &mut ArrayD<F>,
        dropout_rate: f64,
        training: bool,
    ) -> Result<()> {
        if !training {
            return Ok(());
        }

        let keep_prob = 1.0 - dropout_rate;
        let scale_factor = F::from(1.0 / keep_prob).unwrap();

        for element in array.iter_mut() {
            if rand::random::<f64>() < dropout_rate {
                *element = F::zero();
            } else {
                *element = *element * scale_factor;
            }
        }

        Ok(())
    }
}

/// Memory-aware batch processor for handling large datasets
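///
/// Batch sizes shrink automatically as pooled memory approaches the
/// configured threshold. A minimal sketch (paths assumed, not verified):
///
/// ```ignore
/// use ndarray::Array2;
/// use scirs2_neural::memory_efficient::MemoryAwareBatchProcessor;
///
/// // Up to 100 MB of working memory, a 50 MB usage threshold, a 10 MB pool
/// let mut processor = MemoryAwareBatchProcessor::<f32>::new(100, 50.0, 10);
///
/// let input = Array2::from_elem((1_000, 16), 1.0_f32).into_dyn();
/// let outputs = processor
///     .process_batches(&input, |batch| Ok(batch.to_owned()))
///     .unwrap();
/// assert!(!outputs.is_empty());
/// ```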
pub struct MemoryAwareBatchProcessor<F: Float + Debug> {
    /// Maximum batch size based on available memory
    max_batch_size: usize,
    /// Memory pool for tensor reuse
    memory_pool: MemoryPool<F>,
    /// Current memory usage threshold
    memory_threshold_mb: f64,
}

impl<F: Float + Debug + Clone + 'static> MemoryAwareBatchProcessor<F> {
    /// Create a new memory-aware batch processor
    pub fn new(max_memory_mb: usize, memory_threshold_mb: f64, pool_size_mb: usize) -> Self {
        Self {
            max_batch_size: Self::calculate_max_batch_size(max_memory_mb),
            memory_pool: MemoryPool::new(pool_size_mb),
            memory_threshold_mb,
        }
    }

    /// Process batches with automatic size adjustment based on memory usage
    pub fn process_batches<ProcessFn>(
        &mut self,
        input: &ArrayD<F>,
        mut process_fn: ProcessFn,
    ) -> Result<Vec<ArrayD<F>>>
    where
        ProcessFn: FnMut(&ArrayView<F, IxDyn>) -> Result<ArrayD<F>>,
    {
        let total_samples = input.shape()[0];
        let mut results = Vec::new();
        let mut start_idx = 0;

        while start_idx < total_samples {
            // Determine batch size based on current memory usage
            let current_usage = self.memory_pool.get_usage();
            let available_memory_mb = self.memory_threshold_mb - current_usage.current_mb();

            let batch_size = if available_memory_mb < 100.0 {
                // Low memory - use smaller batches
                (self.max_batch_size / 4).max(1)
            } else if available_memory_mb < 200.0 {
                // Medium memory - use half batch size, but always make progress
                (self.max_batch_size / 2).max(1)
            } else {
                // Plenty of memory - use full batch size
                self.max_batch_size
            };

            let end_idx = (start_idx + batch_size).min(total_samples);
            let batch = input.slice(ndarray::s![start_idx..end_idx, ..]).into_dyn();

            // Process the batch
            let result = process_fn(&batch)?;
            results.push(result);

            start_idx = end_idx;

            // Optional: Force garbage collection if memory usage is high
            if current_usage.current_mb() > self.memory_threshold_mb * 0.8 {
                self.memory_pool.clear();
            }
        }

        Ok(results)
    }

    /// Calculate maximum batch size based on available memory
    fn calculate_max_batch_size(max_memory_mb: usize) -> usize {
        // Heuristic: assume each sample uses ~1KB on average
        let max_memory_bytes = max_memory_mb * 1024 * 1024;
        let bytes_per_sample = 1024; // 1KB per sample estimate
        (max_memory_bytes / bytes_per_sample).max(1)
    }

    /// Get current batch processor statistics
    pub fn get_stats(&self) -> BatchProcessorStats {
        let usage = self.memory_pool.get_usage();
        let pool_stats = self.memory_pool.get_pool_stats();

        BatchProcessorStats {
            max_batch_size: self.max_batch_size,
            current_memory_mb: usage.current_mb(),
            peak_memory_mb: usage.peak_mb(),
            memory_threshold_mb: self.memory_threshold_mb,
            pool_stats,
        }
    }
}

/// Statistics for the batch processor
#[derive(Debug, Clone)]
pub struct BatchProcessorStats {
    /// Maximum batch size
    pub max_batch_size: usize,
    /// Current memory usage in MB
    pub current_memory_mb: f64,
    /// Peak memory usage in MB
    pub peak_memory_mb: f64,
    /// Memory threshold in MB
    pub memory_threshold_mb: f64,
    /// Memory pool statistics
    pub pool_stats: PoolStatistics,
}

/// Memory-efficient neural network layer that processes data in chunks
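///
/// The batch dimension is processed in fixed-size chunks so that only one
/// chunk of intermediate results is materialized at a time. A minimal sketch
/// (paths assumed, not verified; weights are zero-initialized here):
///
/// ```ignore
/// use ndarray::Array2;
/// use scirs2_neural::memory_efficient::MemoryEfficientLayer;
///
/// // 128 inputs, 64 outputs, processed 256 rows at a time
/// let layer = MemoryEfficientLayer::new(128, 64, Some(256)).unwrap();
///
/// let input = Array2::from_elem((1_024, 128), 1.0_f32).into_dyn();
/// let output = layer.forward(&input).unwrap();
/// assert_eq!(output.shape(), &[1_024, 64]);
/// ```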
pub struct MemoryEfficientLayer {
    /// Weight matrix stored in memory-efficient format
    #[cfg(feature = "memory_efficient")]
    #[allow(dead_code)]
    weights: ArrayD<f32>,

    /// Bias vector
    bias: ndarray::Array1<f32>,

    /// Chunk size for processing
    chunk_size: usize,

    // Memory manager for efficient allocation
    // FIXME: MemoryManager not available in current scirs2-core
    // #[cfg(feature = "memory_management")]
    // memory_manager: Arc<MemoryManager>,

    /// Buffer pool for temporary allocations
    #[cfg(feature = "memory_management")]
    #[allow(dead_code)]
    buffer_pool: Arc<BufferPool>,

    /// Cache for activations (useful during training)
    #[cfg(feature = "cache")]
    activation_cache: TTLSizedCache<String, ArrayD<f32>>,
}

impl MemoryEfficientLayer {
    /// Create a new memory-efficient layer
    pub fn new(input_size: usize, output_size: usize, chunk_size: Option<usize>) -> Result<Self> {
        let _weights_shape = [input_size, output_size];
        let default_chunk_size = chunk_size.unwrap_or(1024);

        #[cfg(feature = "memory_efficient")]
        let weights = ArrayD::zeros(IxDyn(&_weights_shape));

        let bias = ndarray::Array1::zeros(output_size);

        // FIXME: MemoryManager not available in current scirs2-core
        // #[cfg(feature = "memory_management")]
        // let memory_manager = Arc::new(MemoryManager::new(
        //     AllocationStrategy::FirstFit,
        //     1024 * 1024 * 100,
        // )); // 100MB

        #[cfg(feature = "memory_management")]
        let buffer_pool = Arc::new(
            BufferPool::new(
                1000,                             // pool_size
                default_chunk_size * output_size, // buffer_size
                false,                            // numa_aware
                64,                               // alignment
            )
            .unwrap(),
        );

        #[cfg(feature = "cache")]
        let activation_cache = CacheBuilder::new()
            .with_size(100)
            .with_ttl(300)
            .build_sized_cache();

        Ok(Self {
            #[cfg(feature = "memory_efficient")]
            weights,
            bias,
            chunk_size: default_chunk_size,
            // FIXME: MemoryManager not available in current scirs2-core
            // #[cfg(feature = "memory_management")]
            // memory_manager,
            #[cfg(feature = "memory_management")]
            buffer_pool,
            #[cfg(feature = "cache")]
            activation_cache,
        })
    }

    /// Forward pass with memory-efficient chunk processing
    pub fn forward(&self, input: &ArrayD<f32>) -> Result<ArrayD<f32>> {
        let input_shape = input.shape();
        let batch_size = input_shape[0];
        let _input_size = input_shape[1];
        let output_size = self.bias.len();

        // Create output array
        let mut output = Array::zeros((batch_size, output_size));

        // Process in chunks to minimize memory usage
        let chunks = batch_size.div_ceil(self.chunk_size);

        for chunk_idx in 0..chunks {
            let start_idx = chunk_idx * self.chunk_size;
            let end_idx = std::cmp::min(start_idx + self.chunk_size, batch_size);
            let _chunk_batch_size = end_idx - start_idx;

            // Extract input chunk
            let input_chunk = input.slice(ndarray::s![start_idx..end_idx, ..]);

            // Compute matrix multiplication for this chunk
            #[cfg(feature = "memory_efficient")]
            let chunk_output = self.forward_chunk(&input_chunk.into_dyn())?;

            #[cfg(not(feature = "memory_efficient"))]
            let chunk_output = self.forward_chunk_fallback(&input_chunk.into_dyn())?;

            // Copy result to output array
            output
                .slice_mut(ndarray::s![start_idx..end_idx, ..])
                .assign(&chunk_output);
        }

        Ok(output.into_dyn())
    }

    /// Memory-efficient forward pass for a single chunk
    #[cfg(feature = "memory_efficient")]
    fn forward_chunk(&self, input_chunk: &ArrayView<f32, IxDyn>) -> Result<ndarray::Array2<f32>> {
        // FIXME: chunk_wise_op from scirs2-core is currently disabled due to a
        // signature mismatch, so ChunkForwardProcessor is unused for now:
        // let processor = ChunkForwardProcessor { weights: &self.weights, bias: &self.bias };
        // let result = chunk_wise_op(&input_chunk.to_owned(), 1024, &processor)
        //     .map_err(|e| NeuralError::ComputationError(format!("Chunk-wise operation failed: {:?}", e)))?;

        // Temporary fallback - plain matrix multiplication against the stored weights
        let input_2d = input_chunk
            .view()
            .into_dimensionality::<ndarray::Ix2>()
            .map_err(|e| {
                NeuralError::DimensionMismatch(format!("Failed to convert input to 2D: {}", e))
            })?;
        let weights_2d = self
            .weights
            .view()
            .into_dimensionality::<ndarray::Ix2>()
            .map_err(|e| {
                NeuralError::DimensionMismatch(format!("Failed to convert weights to 2D: {}", e))
            })?;

        // Matrix multiplication followed by a broadcast bias add
        let mut output = input_2d.dot(&weights_2d);
        for mut row in output.rows_mut() {
            for (out_val, bias_val) in row.iter_mut().zip(self.bias.iter()) {
                *out_val += bias_val;
            }
        }

        Ok(output)
    }

    /// Fallback implementation when memory_efficient feature is not available
    #[cfg(not(feature = "memory_efficient"))]
    fn forward_chunk_fallback(
        &self,
        input_chunk: &ArrayView<f32, IxDyn>,
    ) -> Result<ndarray::Array2<f32>> {
        // Simple fallback using regular ndarray operations
        let input_2d = input_chunk
            .view()
            .into_dimensionality::<ndarray::Ix2>()
            .map_err(|e| {
                NeuralError::DimensionMismatch(format!("Failed to convert to 2D: {}", e))
            })?;

        // For fallback, create a simple weight matrix
        let (_chunk_batch_size, input_size) = input_2d.dim();
        let output_size = self.bias.len();
        let weights_2d = ndarray::Array2::<f32>::zeros((input_size, output_size));

        // TODO: Replace with scirs2-core matrix multiplication when available
        // For now, using manual matrix multiplication
        let mut result =
            ndarray::Array2::<f32>::zeros((input_2d.shape()[0], weights_2d.shape()[1]));
        for i in 0..input_2d.shape()[0] {
            for j in 0..weights_2d.shape()[1] {
                let mut sum = 0.0f32;
                for k in 0..input_2d.shape()[1] {
                    sum += input_2d[[i, k]] * weights_2d[[k, j]];
                }
                result[[i, j]] = sum;
            }
        }

        // Add bias
        for mut row in result.rows_mut() {
            for (out_val, bias_val) in row.iter_mut().zip(self.bias.iter()) {
                *out_val += bias_val;
            }
        }

        Ok(result)
    }

    // FIXME: MemoryManager not available in current scirs2-core
    // /// Get memory usage statistics
    // #[cfg(feature = "memory_management")]
    // pub fn get_memory_stats(&self) -> MemoryMetrics {
    //     self.memory_manager.get_metrics()
    // }

    /// Cache activation for reuse during training
    #[cfg(feature = "cache")]
    pub fn cache_activation(&mut self, key: String, activation: ArrayD<f32>) {
        self.activation_cache.insert(key, activation);
    }

    /// Retrieve cached activation
    #[cfg(feature = "cache")]
    pub fn get_cached_activation(&mut self, key: &str) -> Option<ArrayD<f32>> {
        self.activation_cache.get(&key.to_string())
    }
}

/// Processor for chunk-wise forward operations
#[cfg(feature = "memory_efficient")]
#[allow(dead_code)]
struct ChunkForwardProcessor<'a> {
    weights: &'a ArrayD<f32>,
    bias: &'a ndarray::Array1<f32>,
}

// FIXME: ChunkProcessor trait not available in current scirs2-core
// #[cfg(feature = "memory_efficient")]
// impl<'a> ChunkProcessor<f32> for ChunkForwardProcessor<'a> {
//     type Output = ArrayD<f32>;
//     type Error = crate::error::NeuralError;
//
//     fn process_chunk(
//         &self,
//         chunk: ArrayView<f32, IxDyn>,
//     ) -> std::result::Result<ArrayD<f32>, crate::error::NeuralError> {
//         // Simplified processing for demonstration
//         // In a real implementation, this would use the memory-efficient weights
//         Ok(chunk.to_owned())
//     }
// }

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_memory_pool() {
        let mut pool = MemoryPool::<f32>::new(10); // 10MB max

        // Allocate a tensor
        let tensor1 = pool.allocate(&[100, 100]);
        assert_eq!(tensor1.shape(), [100, 100]);

        // Return it to the pool
        pool.deallocate(tensor1);

        // Allocate again - should reuse
        let tensor2 = pool.allocate(&[100, 100]);
        assert_eq!(tensor2.shape(), [100, 100]);

        let stats = pool.get_pool_stats();
        assert_eq!(stats.unique_shapes, 1);
    }

    #[test]
    fn test_gradient_checkpointing() {
        let mut checkpointing = GradientCheckpointing::<f64>::new(100.0); // 100MB threshold

        checkpointing.add_checkpoint_layer("layer1".to_string());

        let activation = Array2::from_elem((10, 10), 1.0).into_dyn();
        checkpointing
            .store_checkpoint("layer1", activation)
            .unwrap();

        assert!(checkpointing.get_checkpoint("layer1").is_some());

        checkpointing.clear_checkpoints();
        assert!(checkpointing.get_checkpoint("layer1").is_none());
    }

    #[test]
    fn test_in_place_operations() {
        let mut array = Array2::from_elem((3, 3), -1.0).into_dyn();

        // Test in-place ReLU
        InPlaceOperations::relu_inplace(&mut array);
        for &val in array.iter() {
            assert!(val >= 0.0);
        }

        // Test in-place scaling
        InPlaceOperations::scale_inplace(&mut array, 2.0);
        for &val in array.iter() {
            assert_eq!(val, 0.0); // Was negative, became 0 after ReLU, then scaled
        }
    }

    #[test]
    fn test_memory_aware_batch_processor() {
        let mut processor = MemoryAwareBatchProcessor::<f32>::new(100, 50.0, 10);

        let input = Array2::from_elem((20, 5), 1.0).into_dyn();

        let results = processor
            .process_batches(&input, |batch| Ok(batch.to_owned()))
            .unwrap();

        assert!(!results.is_empty());

        let stats = processor.get_stats();
        assert!(stats.max_batch_size > 0);
    }

    #[test]
    fn test_memory_usage_tracking() {
        let mut usage = MemoryUsage::new();

        usage.allocate(1024 * 1024); // 1MB
        assert_eq!(usage.current_mb(), 1.0);
        assert_eq!(usage.peak_mb(), 1.0);
        assert_eq!(usage.active_allocations, 1);

        usage.allocate(2 * 1024 * 1024); // 2MB more
        assert_eq!(usage.current_mb(), 3.0);
        assert_eq!(usage.peak_mb(), 3.0);
        assert_eq!(usage.active_allocations, 2);

        usage.deallocate(1024 * 1024); // Release 1MB
        assert_eq!(usage.current_mb(), 2.0);
        assert_eq!(usage.peak_mb(), 3.0); // Peak should remain
        assert_eq!(usage.active_allocations, 1);
    }
}