// tenflowers_autograd/efficient_memory.rs
//! Memory-Efficient Gradient Computation
//!
//! This module provides memory-efficient implementations of gradient computation
//! techniques including gradient checkpointing, memory pooling, and lazy evaluation.

use scirs2_core::numeric::Float;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use tenflowers_core::{Result, Tensor, TensorError};

/// Memory pool for reusing tensor allocations during gradient computation.
///
/// Returned tensors are retained keyed by their exact shape so later requests
/// for the same shape can skip a fresh allocation.
pub struct GradientMemoryPool<T> {
    // Reusable tensors grouped by shape; each queue holds ready-to-reuse buffers.
    available_tensors: HashMap<Vec<usize>, VecDeque<Tensor<T>>>,
    // Maximum number of tensors retained per shape; extra returns are dropped.
    max_pool_size: usize,
    // Count of tensors created fresh (pool misses), not a count of live tensors.
    total_allocated: usize,
}

18impl<T> GradientMemoryPool<T>
19where
20    T: Clone + Default + Send + Sync + 'static + scirs2_core::num_traits::Zero,
21{
22    /// Create a new gradient memory pool
23    pub fn new(max_pool_size: usize) -> Self {
24        Self {
25            available_tensors: HashMap::new(),
26            max_pool_size,
27            total_allocated: 0,
28        }
29    }
30
31    /// Get a tensor from the pool or create a new one
32    pub fn get_tensor(&mut self, shape: &[usize]) -> Tensor<T> {
33        let shape_vec = shape.to_vec();
34
35        if let Some(tensor_queue) = self.available_tensors.get_mut(&shape_vec) {
36            if let Some(tensor) = tensor_queue.pop_front() {
37                return tensor;
38            }
39        }
40
41        // Create new tensor if none available in pool
42        self.total_allocated += 1;
43        Tensor::zeros(shape)
44    }
45
46    /// Return a tensor to the pool for reuse
47    pub fn return_tensor(&mut self, tensor: Tensor<T>) {
48        let shape = tensor.shape().dims().to_vec();
49
50        let tensor_queue = self.available_tensors.entry(shape).or_default();
51
52        if tensor_queue.len() < self.max_pool_size {
53            tensor_queue.push_back(tensor);
54        }
55        // If pool is full, tensor will be dropped
56    }
57
58    /// Get pool statistics
59    pub fn get_stats(&self) -> MemoryPoolStats {
60        let total_pooled: usize = self
61            .available_tensors
62            .values()
63            .map(|queue| queue.len())
64            .sum();
65
66        MemoryPoolStats {
67            total_allocated: self.total_allocated,
68            total_pooled,
69            pool_hit_ratio: if self.total_allocated > 0 {
70                total_pooled as f64 / self.total_allocated as f64
71            } else {
72                0.0
73            },
74        }
75    }
76
77    /// Clear the memory pool
78    pub fn clear(&mut self) {
79        self.available_tensors.clear();
80        self.total_allocated = 0;
81    }
82}
83
/// Statistics for memory pool usage.
#[derive(Debug, Clone)]
pub struct MemoryPoolStats {
    /// Number of tensors created fresh because no pooled tensor matched.
    pub total_allocated: usize,
    /// Number of tensors currently sitting in the pool awaiting reuse.
    pub total_pooled: usize,
    /// Ratio of currently-pooled tensors to fresh allocations; 0.0 when nothing
    /// has been allocated. NOTE(review): despite the name, this is not a true
    /// hit ratio — pool hits are not counted.
    pub pool_hit_ratio: f64,
}

/// Gradient checkpointing manager for memory-efficient backpropagation.
///
/// Stores intermediate tensors up to a byte budget, evicting the least
/// valuable entries (cheap to recompute, long idle) when space runs out.
pub struct GradientCheckpointer<T> {
    // Stored checkpoints keyed by caller-chosen name.
    checkpoints: HashMap<String, CheckpointData<T>>,
    // Soft cap, in bytes, on the estimated memory held by checkpoints.
    memory_budget: usize,
    // Running estimate of bytes currently held in `checkpoints`.
    current_memory_usage: usize,
}

/// Data stored at a checkpoint.
#[derive(Clone)]
struct CheckpointData<T> {
    // The checkpointed tensor itself.
    tensor: Tensor<T>,
    // Caller-estimated cost of recomputing the tensor; higher = evicted later.
    computation_cost: f64,
    // Estimated size of the tensor's element buffer, in bytes.
    memory_size: usize,
    // Timestamp of the last store/retrieve, used for eviction scoring.
    last_accessed: std::time::Instant,
}

108impl<T> GradientCheckpointer<T>
109where
110    T: Clone + Default + Send + Sync + 'static,
111{
112    /// Create a new gradient checkpointer with memory budget
113    pub fn new(memory_budget: usize) -> Self {
114        Self {
115            checkpoints: HashMap::new(),
116            memory_budget,
117            current_memory_usage: 0,
118        }
119    }
120
121    /// Store a checkpoint with estimated computation cost
122    pub fn store_checkpoint(
123        &mut self,
124        name: &str,
125        tensor: Tensor<T>,
126        computation_cost: f64,
127    ) -> Result<()> {
128        let memory_size = self.estimate_tensor_memory_size(&tensor);
129
130        // Evict checkpoints if needed to fit within budget
131        while self.current_memory_usage + memory_size > self.memory_budget
132            && !self.checkpoints.is_empty()
133        {
134            self.evict_least_valuable_checkpoint();
135        }
136
137        let checkpoint_data = CheckpointData {
138            tensor,
139            computation_cost,
140            memory_size,
141            last_accessed: std::time::Instant::now(),
142        };
143
144        if let Some(old_data) = self.checkpoints.insert(name.to_string(), checkpoint_data) {
145            self.current_memory_usage -= old_data.memory_size;
146        }
147
148        self.current_memory_usage += memory_size;
149
150        Ok(())
151    }
152
153    /// Retrieve a checkpoint
154    pub fn get_checkpoint(&mut self, name: &str) -> Option<Tensor<T>> {
155        if let Some(data) = self.checkpoints.get_mut(name) {
156            data.last_accessed = std::time::Instant::now();
157            Some(data.tensor.clone())
158        } else {
159            None
160        }
161    }
162
163    /// Check if a checkpoint exists
164    pub fn has_checkpoint(&self, name: &str) -> bool {
165        self.checkpoints.contains_key(name)
166    }
167
168    /// Evict the least valuable checkpoint based on computation cost and access time
169    fn evict_least_valuable_checkpoint(&mut self) {
170        let mut least_valuable_key = None;
171        let mut least_value_score = f64::INFINITY;
172
173        let now = std::time::Instant::now();
174
175        for (key, data) in &self.checkpoints {
176            let time_since_access = now.duration_since(data.last_accessed).as_secs_f64();
177            // Higher computation cost and recent access make checkpoints more valuable
178            let value_score = data.computation_cost / (time_since_access + 1.0);
179
180            if value_score < least_value_score {
181                least_value_score = value_score;
182                least_valuable_key = Some(key.clone());
183            }
184        }
185
186        if let Some(key) = least_valuable_key {
187            if let Some(removed_data) = self.checkpoints.remove(&key) {
188                self.current_memory_usage -= removed_data.memory_size;
189            }
190        }
191    }
192
193    /// Estimate memory size of a tensor
194    fn estimate_tensor_memory_size(&self, tensor: &Tensor<T>) -> usize {
195        let element_count: usize = tensor.shape().dims().iter().product();
196        element_count * std::mem::size_of::<T>()
197    }
198
199    /// Get checkpointing statistics
200    pub fn get_stats(&self) -> CheckpointStats {
201        CheckpointStats {
202            num_checkpoints: self.checkpoints.len(),
203            memory_usage: self.current_memory_usage,
204            memory_budget: self.memory_budget,
205            memory_utilization: self.current_memory_usage as f64 / self.memory_budget as f64,
206        }
207    }
208}
209
/// Statistics for gradient checkpointing.
#[derive(Debug, Clone)]
pub struct CheckpointStats {
    /// Number of checkpoints currently stored.
    pub num_checkpoints: usize,
    /// Estimated bytes currently held by stored checkpoints.
    pub memory_usage: usize,
    /// Configured byte budget for checkpoints.
    pub memory_budget: usize,
    /// `memory_usage / memory_budget` as a fraction.
    pub memory_utilization: f64,
}

/// Lazy gradient computation that defers expensive operations.
///
/// The wrapped closure runs at most once per cache fill; its result is cached
/// behind a mutex and cloned out on subsequent accesses.
pub struct LazyGradient<T> {
    // Deferred computation; invoked on first `get` (or again after `clear_cache`).
    computation: Box<dyn Fn() -> Result<Tensor<T>> + Send + Sync>,
    // Cached result of the computation; `None` until the first successful `get`.
    cached_result: Arc<Mutex<Option<Tensor<T>>>>,
    // Caller-supplied hint; not used internally for scheduling.
    is_expensive: bool,
}

226impl<T> LazyGradient<T>
227where
228    T: Clone + Default + Send + Sync + 'static,
229{
230    /// Create a new lazy gradient computation
231    pub fn new<F>(computation: F, is_expensive: bool) -> Self
232    where
233        F: Fn() -> Result<Tensor<T>> + Send + Sync + 'static,
234    {
235        Self {
236            computation: Box::new(computation),
237            cached_result: Arc::new(Mutex::new(None)),
238            is_expensive,
239        }
240    }
241
242    /// Get the computed gradient, computing it if necessary
243    pub fn get(&self) -> Result<Tensor<T>> {
244        let mut cached = self
245            .cached_result
246            .lock()
247            .expect("lock should not be poisoned");
248
249        if let Some(result) = &*cached {
250            return Ok(result.clone());
251        }
252
253        // Compute the gradient
254        let result = (self.computation)()?;
255        *cached = Some(result.clone());
256
257        Ok(result)
258    }
259
260    /// Check if the gradient has been computed
261    pub fn is_computed(&self) -> bool {
262        self.cached_result
263            .lock()
264            .expect("lock should not be poisoned")
265            .is_some()
266    }
267
268    /// Clear cached result to free memory
269    pub fn clear_cache(&self) {
270        *self
271            .cached_result
272            .lock()
273            .expect("lock should not be poisoned") = None;
274    }
275
276    /// Check if this is an expensive computation
277    pub fn is_expensive(&self) -> bool {
278        self.is_expensive
279    }
280}
281
/// Memory-efficient gradient aggregation with streaming computation.
///
/// Buffers incoming gradients and periodically folds them into a running sum
/// so only a bounded number of tensors is alive at once.
pub struct StreamingGradientAggregator<T> {
    // Running sum of all flushed gradients; `None` until the first flush.
    accumulated_gradient: Option<Tensor<T>>,
    // Total number of gradients added via `add_gradient`.
    count: usize,
    // Buffer length (in tensors, not bytes) that triggers a flush.
    memory_threshold: usize,
    // Gradients awaiting the next flush into the accumulator.
    temp_gradients: Vec<Tensor<T>>,
}

290impl<T> StreamingGradientAggregator<T>
291where
292    T: Float + Clone + Default + Send + Sync + 'static + bytemuck::Pod + bytemuck::Zeroable,
293{
294    /// Create a new streaming gradient aggregator
295    pub fn new(memory_threshold: usize) -> Self {
296        Self {
297            accumulated_gradient: None,
298            count: 0,
299            memory_threshold,
300            temp_gradients: Vec::new(),
301        }
302    }
303
304    /// Add a gradient to the aggregation
305    pub fn add_gradient(&mut self, gradient: Tensor<T>) -> Result<()> {
306        self.temp_gradients.push(gradient);
307
308        // Check if we should flush to avoid memory buildup
309        if self.temp_gradients.len() >= self.memory_threshold {
310            self.flush_temp_gradients()?;
311        }
312
313        self.count += 1;
314        Ok(())
315    }
316
317    /// Flush temporary gradients to the main accumulator
318    fn flush_temp_gradients(&mut self) -> Result<()> {
319        if self.temp_gradients.is_empty() {
320            return Ok(());
321        }
322
323        // Sum all temporary gradients
324        let mut temp_sum = self.temp_gradients[0].clone();
325        for grad in &self.temp_gradients[1..] {
326            temp_sum = temp_sum.add(grad)?;
327        }
328
329        // Add to main accumulator
330        self.accumulated_gradient = match &self.accumulated_gradient {
331            Some(acc) => Some(acc.add(&temp_sum)?),
332            None => Some(temp_sum),
333        };
334
335        // Clear temporary storage
336        self.temp_gradients.clear();
337
338        Ok(())
339    }
340
341    /// Get the final aggregated gradient
342    pub fn finalize(&mut self) -> Result<Option<Tensor<T>>> {
343        // Flush any remaining temporary gradients
344        self.flush_temp_gradients()?;
345
346        if let Some(acc_grad) = &self.accumulated_gradient {
347            if self.count > 0 {
348                let count_scalar = Tensor::from_scalar(
349                    T::from(self.count).expect("count should convert to float"),
350                );
351                let avg_grad = acc_grad.div(&count_scalar)?;
352                Ok(Some(avg_grad))
353            } else {
354                Ok(None)
355            }
356        } else {
357            Ok(None)
358        }
359    }
360
361    /// Get current aggregation statistics
362    pub fn get_stats(&self) -> AggregationStats {
363        AggregationStats {
364            total_gradients: self.count,
365            temp_gradients_count: self.temp_gradients.len(),
366            has_accumulated: self.accumulated_gradient.is_some(),
367        }
368    }
369
370    /// Reset the aggregator
371    pub fn reset(&mut self) {
372        self.accumulated_gradient = None;
373        self.count = 0;
374        self.temp_gradients.clear();
375    }
376}
377
/// Statistics for gradient aggregation.
#[derive(Debug, Clone)]
pub struct AggregationStats {
    /// Total gradients added so far.
    pub total_gradients: usize,
    /// Gradients currently buffered and awaiting a flush.
    pub temp_gradients_count: usize,
    /// Whether at least one flush has populated the accumulator.
    pub has_accumulated: bool,
}

/// Global memory manager for gradient computations.
///
/// Bundles a shared tensor pool, a checkpointer, and a list of lazy
/// computations behind thread-safe handles.
pub struct GradientMemoryManager<T> {
    // Shared pool of reusable tensor allocations.
    memory_pool: Arc<Mutex<GradientMemoryPool<T>>>,
    // Checkpointer; constructed with half of `memory_limit` as its budget.
    checkpointer: Arc<Mutex<GradientCheckpointer<T>>>,
    // Registered lazy gradient computations (currently only counted in stats).
    lazy_computations: Vec<LazyGradient<T>>,
    // Overall memory limit, in bytes, reported in stats.
    memory_limit: usize,
}

394impl<T> GradientMemoryManager<T>
395where
396    T: Clone + Default + Send + Sync + 'static + scirs2_core::num_traits::Zero,
397{
398    /// Create a new gradient memory manager
399    pub fn new(memory_limit: usize, pool_size: usize) -> Self {
400        Self {
401            memory_pool: Arc::new(Mutex::new(GradientMemoryPool::new(pool_size))),
402            checkpointer: Arc::new(Mutex::new(GradientCheckpointer::new(memory_limit / 2))),
403            lazy_computations: Vec::new(),
404            memory_limit,
405        }
406    }
407
408    /// Get a tensor from the memory pool
409    pub fn get_tensor(&self, shape: &[usize]) -> Result<Tensor<T>> {
410        let mut pool = self
411            .memory_pool
412            .lock()
413            .map_err(|_| TensorError::InvalidArgument {
414                operation: "get_tensor".to_string(),
415                reason: "Failed to acquire memory pool lock".to_string(),
416                context: None,
417            })?;
418
419        Ok(pool.get_tensor(shape))
420    }
421
422    /// Return a tensor to the memory pool
423    pub fn return_tensor(&self, tensor: Tensor<T>) -> Result<()> {
424        let mut pool = self
425            .memory_pool
426            .lock()
427            .map_err(|_| TensorError::InvalidArgument {
428                operation: "return_tensor".to_string(),
429                reason: "Failed to acquire memory pool lock".to_string(),
430                context: None,
431            })?;
432
433        pool.return_tensor(tensor);
434        Ok(())
435    }
436
437    /// Store a checkpoint
438    pub fn store_checkpoint(&self, name: &str, tensor: Tensor<T>, cost: f64) -> Result<()> {
439        let mut checkpointer =
440            self.checkpointer
441                .lock()
442                .map_err(|_| TensorError::InvalidArgument {
443                    operation: "store_checkpoint".to_string(),
444                    reason: "Failed to acquire checkpointer lock".to_string(),
445                    context: None,
446                })?;
447
448        checkpointer.store_checkpoint(name, tensor, cost)
449    }
450
451    /// Get memory usage statistics
452    pub fn get_memory_stats(&self) -> Result<MemoryManagerStats> {
453        let pool = self
454            .memory_pool
455            .lock()
456            .map_err(|_| TensorError::InvalidArgument {
457                operation: "get_memory_stats".to_string(),
458                reason: "Failed to acquire memory pool lock".to_string(),
459                context: None,
460            })?;
461
462        let checkpointer = self
463            .checkpointer
464            .lock()
465            .map_err(|_| TensorError::InvalidArgument {
466                operation: "get_memory_stats".to_string(),
467                reason: "Failed to acquire checkpointer lock".to_string(),
468                context: None,
469            })?;
470
471        Ok(MemoryManagerStats {
472            pool_stats: pool.get_stats(),
473            checkpoint_stats: checkpointer.get_stats(),
474            lazy_computations_count: self.lazy_computations.len(),
475            memory_limit: self.memory_limit,
476        })
477    }
478}
479
/// Combined statistics for the memory manager.
#[derive(Debug, Clone)]
pub struct MemoryManagerStats {
    /// Statistics from the tensor memory pool.
    pub pool_stats: MemoryPoolStats,
    /// Statistics from the gradient checkpointer.
    pub checkpoint_stats: CheckpointStats,
    /// Number of registered lazy gradient computations.
    pub lazy_computations_count: usize,
    /// Configured overall memory limit, in bytes.
    pub memory_limit: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_pool() {
        let mut pool = GradientMemoryPool::<f32>::new(10);

        // Two requests with nothing pooled: both are misses (fresh allocations).
        let tensor1 = pool.get_tensor(&[2, 3]);
        let _tensor2 = pool.get_tensor(&[2, 3]);

        // Returning one and requesting the same shape should reuse it.
        pool.return_tensor(tensor1);
        let tensor3 = pool.get_tensor(&[2, 3]);
        assert_eq!(tensor3.shape().dims(), &[2, 3]);

        // Exactly the two misses count as allocations; the reuse adds none.
        let stats = pool.get_stats();
        assert_eq!(stats.total_allocated, 2);
    }

    #[test]
    fn test_checkpointer() {
        let mut checkpointer = GradientCheckpointer::<f32>::new(1024);

        let tensor = Tensor::ones(&[2, 2]);
        checkpointer
            .store_checkpoint("test", tensor.clone(), 10.0)
            .expect("test: operation should succeed");

        assert!(checkpointer.has_checkpoint("test"));
        let retrieved = checkpointer
            .get_checkpoint("test")
            .expect("test: checkpoint operation should succeed");

        // The retrieved checkpoint preserves the stored tensor's shape.
        assert_eq!(tensor.shape().dims(), retrieved.shape().dims());

        // An unknown name is reported as absent, not as an error.
        assert!(!checkpointer.has_checkpoint("missing"));
    }

    #[test]
    fn test_streaming_aggregator() {
        let mut aggregator = StreamingGradientAggregator::<f32>::new(5);

        // Add ten gradients; the buffer flushes every five.
        for i in 0..10 {
            let grad = Tensor::from_scalar(i as f32)
                .broadcast_to(&[2, 2])
                .expect("test: gradient computation should succeed");
            aggregator
                .add_gradient(grad)
                .expect("test: gradient computation should succeed");
        }

        // Check stats before finalize: all ten counted, buffer fully flushed.
        let stats = aggregator.get_stats();
        assert_eq!(stats.total_gradients, 10);
        assert_eq!(stats.temp_gradients_count, 0);

        let final_grad = aggregator
            .finalize()
            .expect("test: gradient computation should succeed");
        assert!(final_grad.is_some());
    }
}