1use scirs2_core::numeric::Float;
7use std::collections::{HashMap, VecDeque};
8use std::sync::{Arc, Mutex};
9use tenflowers_core::{Result, Tensor, TensorError};
10
/// Pool of reusable gradient tensors keyed by shape, reducing allocations
/// during repeated backward passes.
pub struct GradientMemoryPool<T> {
    // Idle tensors grouped by their shape (dims vector).
    available_tensors: HashMap<Vec<usize>, VecDeque<Tensor<T>>>,
    // Maximum number of idle tensors retained per shape.
    max_pool_size: usize,
    // Count of tensors newly allocated by the pool (i.e. pool misses).
    total_allocated: usize,
}
17
18impl<T> GradientMemoryPool<T>
19where
20 T: Clone + Default + Send + Sync + 'static + scirs2_core::num_traits::Zero,
21{
22 pub fn new(max_pool_size: usize) -> Self {
24 Self {
25 available_tensors: HashMap::new(),
26 max_pool_size,
27 total_allocated: 0,
28 }
29 }
30
31 pub fn get_tensor(&mut self, shape: &[usize]) -> Tensor<T> {
33 let shape_vec = shape.to_vec();
34
35 if let Some(tensor_queue) = self.available_tensors.get_mut(&shape_vec) {
36 if let Some(tensor) = tensor_queue.pop_front() {
37 return tensor;
38 }
39 }
40
41 self.total_allocated += 1;
43 Tensor::zeros(shape)
44 }
45
46 pub fn return_tensor(&mut self, tensor: Tensor<T>) {
48 let shape = tensor.shape().dims().to_vec();
49
50 let tensor_queue = self.available_tensors.entry(shape).or_default();
51
52 if tensor_queue.len() < self.max_pool_size {
53 tensor_queue.push_back(tensor);
54 }
55 }
57
58 pub fn get_stats(&self) -> MemoryPoolStats {
60 let total_pooled: usize = self
61 .available_tensors
62 .values()
63 .map(|queue| queue.len())
64 .sum();
65
66 MemoryPoolStats {
67 total_allocated: self.total_allocated,
68 total_pooled,
69 pool_hit_ratio: if self.total_allocated > 0 {
70 total_pooled as f64 / self.total_allocated as f64
71 } else {
72 0.0
73 },
74 }
75 }
76
77 pub fn clear(&mut self) {
79 self.available_tensors.clear();
80 self.total_allocated = 0;
81 }
82}
83
/// Point-in-time counters reported by `GradientMemoryPool::get_stats`.
#[derive(Debug, Clone)]
pub struct MemoryPoolStats {
    /// Number of tensors newly allocated by the pool (pool misses).
    pub total_allocated: usize,
    /// Number of tensors currently idle in the pool, across all shapes.
    pub total_pooled: usize,
    /// `total_pooled / total_allocated` (0.0 when nothing was allocated);
    /// a reuse indicator rather than a classic hit/request ratio.
    pub pool_hit_ratio: f64,
}
91
/// Stores named gradient checkpoints under a byte budget, evicting the
/// least valuable entries (cheap to recompute, long unused) when full.
pub struct GradientCheckpointer<T> {
    // Stored checkpoints keyed by caller-supplied name.
    checkpoints: HashMap<String, CheckpointData<T>>,
    // Target upper bound on total checkpoint memory, in bytes.
    memory_budget: usize,
    // Estimated bytes currently held by `checkpoints`.
    current_memory_usage: usize,
}
98
/// A stored checkpoint plus the bookkeeping used for eviction decisions.
#[derive(Clone)]
struct CheckpointData<T> {
    // The checkpointed tensor value.
    tensor: Tensor<T>,
    // Caller-supplied cost of recomputing this tensor; higher cost makes
    // the entry more valuable and less likely to be evicted.
    computation_cost: f64,
    // Estimated size of `tensor` in bytes (element count * size_of::<T>()).
    memory_size: usize,
    // Updated on every `get_checkpoint`; staler entries are cheaper to evict.
    last_accessed: std::time::Instant,
}
107
108impl<T> GradientCheckpointer<T>
109where
110 T: Clone + Default + Send + Sync + 'static,
111{
112 pub fn new(memory_budget: usize) -> Self {
114 Self {
115 checkpoints: HashMap::new(),
116 memory_budget,
117 current_memory_usage: 0,
118 }
119 }
120
121 pub fn store_checkpoint(
123 &mut self,
124 name: &str,
125 tensor: Tensor<T>,
126 computation_cost: f64,
127 ) -> Result<()> {
128 let memory_size = self.estimate_tensor_memory_size(&tensor);
129
130 while self.current_memory_usage + memory_size > self.memory_budget
132 && !self.checkpoints.is_empty()
133 {
134 self.evict_least_valuable_checkpoint();
135 }
136
137 let checkpoint_data = CheckpointData {
138 tensor,
139 computation_cost,
140 memory_size,
141 last_accessed: std::time::Instant::now(),
142 };
143
144 if let Some(old_data) = self.checkpoints.insert(name.to_string(), checkpoint_data) {
145 self.current_memory_usage -= old_data.memory_size;
146 }
147
148 self.current_memory_usage += memory_size;
149
150 Ok(())
151 }
152
153 pub fn get_checkpoint(&mut self, name: &str) -> Option<Tensor<T>> {
155 if let Some(data) = self.checkpoints.get_mut(name) {
156 data.last_accessed = std::time::Instant::now();
157 Some(data.tensor.clone())
158 } else {
159 None
160 }
161 }
162
163 pub fn has_checkpoint(&self, name: &str) -> bool {
165 self.checkpoints.contains_key(name)
166 }
167
168 fn evict_least_valuable_checkpoint(&mut self) {
170 let mut least_valuable_key = None;
171 let mut least_value_score = f64::INFINITY;
172
173 let now = std::time::Instant::now();
174
175 for (key, data) in &self.checkpoints {
176 let time_since_access = now.duration_since(data.last_accessed).as_secs_f64();
177 let value_score = data.computation_cost / (time_since_access + 1.0);
179
180 if value_score < least_value_score {
181 least_value_score = value_score;
182 least_valuable_key = Some(key.clone());
183 }
184 }
185
186 if let Some(key) = least_valuable_key {
187 if let Some(removed_data) = self.checkpoints.remove(&key) {
188 self.current_memory_usage -= removed_data.memory_size;
189 }
190 }
191 }
192
193 fn estimate_tensor_memory_size(&self, tensor: &Tensor<T>) -> usize {
195 let element_count: usize = tensor.shape().dims().iter().product();
196 element_count * std::mem::size_of::<T>()
197 }
198
199 pub fn get_stats(&self) -> CheckpointStats {
201 CheckpointStats {
202 num_checkpoints: self.checkpoints.len(),
203 memory_usage: self.current_memory_usage,
204 memory_budget: self.memory_budget,
205 memory_utilization: self.current_memory_usage as f64 / self.memory_budget as f64,
206 }
207 }
208}
209
/// Point-in-time counters reported by `GradientCheckpointer::get_stats`.
#[derive(Debug, Clone)]
pub struct CheckpointStats {
    /// Number of checkpoints currently stored.
    pub num_checkpoints: usize,
    /// Estimated bytes held by stored checkpoints.
    pub memory_usage: usize,
    /// Configured byte budget.
    pub memory_budget: usize,
    /// `memory_usage / memory_budget`; may exceed 1.0 when an oversized
    /// checkpoint was stored best-effort.
    pub memory_utilization: f64,
}
218
/// A gradient that is computed on first access and cached thereafter.
pub struct LazyGradient<T> {
    // Deferred computation; invoked at most once per cache fill.
    computation: Box<dyn Fn() -> Result<Tensor<T>> + Send + Sync>,
    // Shared cache of the computed result; `None` until the first `get`.
    cached_result: Arc<Mutex<Option<Tensor<T>>>>,
    // Caller-supplied hint that the computation is costly.
    is_expensive: bool,
}
225
226impl<T> LazyGradient<T>
227where
228 T: Clone + Default + Send + Sync + 'static,
229{
230 pub fn new<F>(computation: F, is_expensive: bool) -> Self
232 where
233 F: Fn() -> Result<Tensor<T>> + Send + Sync + 'static,
234 {
235 Self {
236 computation: Box::new(computation),
237 cached_result: Arc::new(Mutex::new(None)),
238 is_expensive,
239 }
240 }
241
242 pub fn get(&self) -> Result<Tensor<T>> {
244 let mut cached = self
245 .cached_result
246 .lock()
247 .expect("lock should not be poisoned");
248
249 if let Some(result) = &*cached {
250 return Ok(result.clone());
251 }
252
253 let result = (self.computation)()?;
255 *cached = Some(result.clone());
256
257 Ok(result)
258 }
259
260 pub fn is_computed(&self) -> bool {
262 self.cached_result
263 .lock()
264 .expect("lock should not be poisoned")
265 .is_some()
266 }
267
268 pub fn clear_cache(&self) {
270 *self
271 .cached_result
272 .lock()
273 .expect("lock should not be poisoned") = None;
274 }
275
276 pub fn is_expensive(&self) -> bool {
278 self.is_expensive
279 }
280}
281
/// Accumulates gradients incrementally, summing in batches so that at most
/// `memory_threshold` unsummed tensors are held at a time.
pub struct StreamingGradientAggregator<T> {
    // Running sum of all flushed gradients; `None` until the first flush.
    accumulated_gradient: Option<Tensor<T>>,
    // Total number of gradients added (used as the averaging divisor).
    count: usize,
    // Batch size: buffered gradients are folded into the accumulator once
    // this many have been collected.
    memory_threshold: usize,
    // Buffer of gradients not yet folded into `accumulated_gradient`.
    temp_gradients: Vec<Tensor<T>>,
}
289
290impl<T> StreamingGradientAggregator<T>
291where
292 T: Float + Clone + Default + Send + Sync + 'static + bytemuck::Pod + bytemuck::Zeroable,
293{
294 pub fn new(memory_threshold: usize) -> Self {
296 Self {
297 accumulated_gradient: None,
298 count: 0,
299 memory_threshold,
300 temp_gradients: Vec::new(),
301 }
302 }
303
304 pub fn add_gradient(&mut self, gradient: Tensor<T>) -> Result<()> {
306 self.temp_gradients.push(gradient);
307
308 if self.temp_gradients.len() >= self.memory_threshold {
310 self.flush_temp_gradients()?;
311 }
312
313 self.count += 1;
314 Ok(())
315 }
316
317 fn flush_temp_gradients(&mut self) -> Result<()> {
319 if self.temp_gradients.is_empty() {
320 return Ok(());
321 }
322
323 let mut temp_sum = self.temp_gradients[0].clone();
325 for grad in &self.temp_gradients[1..] {
326 temp_sum = temp_sum.add(grad)?;
327 }
328
329 self.accumulated_gradient = match &self.accumulated_gradient {
331 Some(acc) => Some(acc.add(&temp_sum)?),
332 None => Some(temp_sum),
333 };
334
335 self.temp_gradients.clear();
337
338 Ok(())
339 }
340
341 pub fn finalize(&mut self) -> Result<Option<Tensor<T>>> {
343 self.flush_temp_gradients()?;
345
346 if let Some(acc_grad) = &self.accumulated_gradient {
347 if self.count > 0 {
348 let count_scalar = Tensor::from_scalar(
349 T::from(self.count).expect("count should convert to float"),
350 );
351 let avg_grad = acc_grad.div(&count_scalar)?;
352 Ok(Some(avg_grad))
353 } else {
354 Ok(None)
355 }
356 } else {
357 Ok(None)
358 }
359 }
360
361 pub fn get_stats(&self) -> AggregationStats {
363 AggregationStats {
364 total_gradients: self.count,
365 temp_gradients_count: self.temp_gradients.len(),
366 has_accumulated: self.accumulated_gradient.is_some(),
367 }
368 }
369
370 pub fn reset(&mut self) {
372 self.accumulated_gradient = None;
373 self.count = 0;
374 self.temp_gradients.clear();
375 }
376}
377
/// Point-in-time counters reported by `StreamingGradientAggregator::get_stats`.
#[derive(Debug, Clone)]
pub struct AggregationStats {
    /// Total number of gradients added so far.
    pub total_gradients: usize,
    /// Gradients currently buffered but not yet folded into the accumulator.
    pub temp_gradients_count: usize,
    /// True once at least one batch has been folded into the accumulator.
    pub has_accumulated: bool,
}
385
/// Facade combining the tensor pool, checkpointer, and lazy computations
/// behind thread-safe (`Arc<Mutex<_>>`) handles.
pub struct GradientMemoryManager<T> {
    // Shared pool of reusable gradient tensors.
    memory_pool: Arc<Mutex<GradientMemoryPool<T>>>,
    // Shared checkpoint store (budgeted at half of `memory_limit` in `new`).
    checkpointer: Arc<Mutex<GradientCheckpointer<T>>>,
    // Registered lazy gradient computations (currently only counted in stats).
    lazy_computations: Vec<LazyGradient<T>>,
    // Overall memory limit in bytes, reported in stats.
    memory_limit: usize,
}
393
394impl<T> GradientMemoryManager<T>
395where
396 T: Clone + Default + Send + Sync + 'static + scirs2_core::num_traits::Zero,
397{
398 pub fn new(memory_limit: usize, pool_size: usize) -> Self {
400 Self {
401 memory_pool: Arc::new(Mutex::new(GradientMemoryPool::new(pool_size))),
402 checkpointer: Arc::new(Mutex::new(GradientCheckpointer::new(memory_limit / 2))),
403 lazy_computations: Vec::new(),
404 memory_limit,
405 }
406 }
407
408 pub fn get_tensor(&self, shape: &[usize]) -> Result<Tensor<T>> {
410 let mut pool = self
411 .memory_pool
412 .lock()
413 .map_err(|_| TensorError::InvalidArgument {
414 operation: "get_tensor".to_string(),
415 reason: "Failed to acquire memory pool lock".to_string(),
416 context: None,
417 })?;
418
419 Ok(pool.get_tensor(shape))
420 }
421
422 pub fn return_tensor(&self, tensor: Tensor<T>) -> Result<()> {
424 let mut pool = self
425 .memory_pool
426 .lock()
427 .map_err(|_| TensorError::InvalidArgument {
428 operation: "return_tensor".to_string(),
429 reason: "Failed to acquire memory pool lock".to_string(),
430 context: None,
431 })?;
432
433 pool.return_tensor(tensor);
434 Ok(())
435 }
436
437 pub fn store_checkpoint(&self, name: &str, tensor: Tensor<T>, cost: f64) -> Result<()> {
439 let mut checkpointer =
440 self.checkpointer
441 .lock()
442 .map_err(|_| TensorError::InvalidArgument {
443 operation: "store_checkpoint".to_string(),
444 reason: "Failed to acquire checkpointer lock".to_string(),
445 context: None,
446 })?;
447
448 checkpointer.store_checkpoint(name, tensor, cost)
449 }
450
451 pub fn get_memory_stats(&self) -> Result<MemoryManagerStats> {
453 let pool = self
454 .memory_pool
455 .lock()
456 .map_err(|_| TensorError::InvalidArgument {
457 operation: "get_memory_stats".to_string(),
458 reason: "Failed to acquire memory pool lock".to_string(),
459 context: None,
460 })?;
461
462 let checkpointer = self
463 .checkpointer
464 .lock()
465 .map_err(|_| TensorError::InvalidArgument {
466 operation: "get_memory_stats".to_string(),
467 reason: "Failed to acquire checkpointer lock".to_string(),
468 context: None,
469 })?;
470
471 Ok(MemoryManagerStats {
472 pool_stats: pool.get_stats(),
473 checkpoint_stats: checkpointer.get_stats(),
474 lazy_computations_count: self.lazy_computations.len(),
475 memory_limit: self.memory_limit,
476 })
477 }
478}
479
/// Combined statistics reported by `GradientMemoryManager::get_memory_stats`.
#[derive(Debug, Clone)]
pub struct MemoryManagerStats {
    /// Tensor-pool counters.
    pub pool_stats: MemoryPoolStats,
    /// Checkpoint-store counters.
    pub checkpoint_stats: CheckpointStats,
    /// Number of registered lazy computations.
    pub lazy_computations_count: usize,
    /// Configured overall memory limit in bytes.
    pub memory_limit: usize,
}
488
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_pool() {
        let mut pool = GradientMemoryPool::<f32>::new(10);

        let tensor1 = pool.get_tensor(&[2, 3]);
        let _tensor2 = pool.get_tensor(&[2, 3]);

        // Returning tensor1 lets the next request reuse it, so three gets
        // with one return must allocate exactly two tensors.
        pool.return_tensor(tensor1);
        let _tensor3 = pool.get_tensor(&[2, 3]);

        let stats = pool.get_stats();
        assert_eq!(stats.total_allocated, 2);
    }

    #[test]
    fn test_checkpointer() {
        let mut checkpointer = GradientCheckpointer::<f32>::new(1024);

        let tensor = Tensor::ones(&[2, 2]);
        checkpointer
            .store_checkpoint("test", tensor.clone(), 10.0)
            .expect("test: operation should succeed");

        assert!(checkpointer.has_checkpoint("test"));
        let retrieved = checkpointer
            .get_checkpoint("test")
            .expect("test: checkpoint operation should succeed");

        assert_eq!(tensor.shape().dims(), retrieved.shape().dims());
    }

    #[test]
    fn test_streaming_aggregator() {
        // Threshold of 5 forces an intermediate flush during the 10 adds.
        let mut aggregator = StreamingGradientAggregator::<f32>::new(5);

        for i in 0..10 {
            let grad = Tensor::from_scalar(i as f32)
                .broadcast_to(&[2, 2])
                .expect("test: gradient computation should succeed");
            aggregator
                .add_gradient(grad)
                .expect("test: gradient computation should succeed");
        }

        let final_grad = aggregator
            .finalize()
            .expect("test: gradient computation should succeed");
        assert!(final_grad.is_some());

        let stats = aggregator.get_stats();
        assert_eq!(stats.total_gradients, 10);
    }
}