1use crate::{
7 Optimizer, OptimizerError, OptimizerResult, OptimizerState, ParamGroup, ParamGroupState,
8};
9use parking_lot::RwLock;
10use std::collections::HashMap;
11use std::ops::Add;
12use std::sync::Arc;
13use torsh_core::error::{Result, TorshError};
14use torsh_tensor::Tensor;
15
16pub struct MemoryPool {
18 tensors: Vec<Tensor>,
19 shapes_cache: HashMap<Vec<usize>, Vec<usize>>, }
21
22impl MemoryPool {
23 pub fn new() -> Self {
24 Self {
25 tensors: Vec::new(),
26 shapes_cache: HashMap::new(),
27 }
28 }
29
30 pub fn get_tensor(
32 &mut self,
33 shape: &[usize],
34 device: torsh_core::device::DeviceType,
35 ) -> torsh_core::error::Result<Tensor> {
36 let shape_vec = shape.to_vec();
37
38 if let Some(indices) = self.shapes_cache.get_mut(&shape_vec) {
39 if let Some(idx) = indices.pop() {
40 let mut tensor = self.tensors.swap_remove(idx);
42 let _ = tensor.zero_();
43 return Ok(tensor);
44 }
45 }
46
47 Ok(Tensor::zeros(shape, device)?)
49 }
50
51 pub fn return_tensor(&mut self, tensor: Tensor) {
53 let shape = tensor.shape().dims().to_vec();
54 let idx = self.tensors.len();
55 self.tensors.push(tensor);
56
57 self.shapes_cache.entry(shape).or_default().push(idx);
58 }
59
60 pub fn clear(&mut self) {
62 self.tensors.clear();
63 self.shapes_cache.clear();
64 }
65}
66
/// Tuning knobs shared by the memory-efficient optimizers in this module.
#[derive(Clone)]
pub struct MemoryConfig {
    /// Hard cap on optimizer-state memory in bytes; `0` disables the cap
    /// (see `MemoryEfficientAdam::can_allocate`).
    pub max_memory_bytes: usize,
    /// Reuse tensor allocations through `MemoryPool` instead of allocating
    /// fresh zero tensors for optimizer state.
    pub use_memory_pool: bool,
    /// Periodically re-quantize optimizer-state tensors (8-bit min/max
    /// scheme in `quantize_tensor`) to trade precision for memory.
    pub compress_state: bool,
    /// NOTE(review): declared but never read anywhere in this file —
    /// presumably consumed elsewhere or reserved; confirm before removal.
    pub lazy_gradients: bool,
    /// Number of optimizer steps between state-compression passes.
    pub checkpoint_interval: usize,
}

impl Default for MemoryConfig {
    fn default() -> Self {
        Self {
            max_memory_bytes: 0, // 0 == unlimited
            use_memory_pool: true,
            compress_state: false,
            lazy_gradients: true,
            checkpoint_interval: 100,
        }
    }
}
93
/// Adam optimizer variant that bounds, tracks, and recycles the memory used
/// by its per-parameter state buffers.
pub struct MemoryEfficientAdam {
    /// Parameter groups, each with its own learning rate.
    param_groups: Vec<ParamGroup>,
    /// Per-parameter state keyed by pointer identity (see `get_param_id`);
    /// the inner map holds "momentum"/"velocity"/"max_exp_avg_sq" buffers.
    state: HashMap<String, HashMap<String, Tensor>>,
    /// Number of `step()` calls so far; drives bias correction.
    step_count: usize,

    // Standard Adam hyperparameters.
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,

    /// Pool for recycling state-tensor allocations.
    memory_pool: MemoryPool,
    /// Memory-related behavior switches.
    config: MemoryConfig,
    /// Estimated bytes of optimizer state currently allocated.
    memory_usage: usize,
}
112
113impl MemoryEfficientAdam {
114 #[allow(clippy::too_many_arguments)]
115 pub fn new(
116 params: Vec<Arc<RwLock<Tensor>>>,
117 lr: f32,
118 beta1: Option<f32>,
119 beta2: Option<f32>,
120 eps: Option<f32>,
121 weight_decay: Option<f32>,
122 amsgrad: Option<bool>,
123 memory_config: Option<MemoryConfig>,
124 ) -> Self {
125 let param_group = ParamGroup::new(params, lr);
126
127 Self {
128 param_groups: vec![param_group],
129 state: HashMap::new(),
130 step_count: 0,
131 beta1: beta1.unwrap_or(0.9),
132 beta2: beta2.unwrap_or(0.999),
133 eps: eps.unwrap_or(1e-8),
134 weight_decay: weight_decay.unwrap_or(0.0),
135 amsgrad: amsgrad.unwrap_or(false),
136 memory_pool: MemoryPool::new(),
137 config: memory_config.unwrap_or_default(),
138 memory_usage: 0,
139 }
140 }
141
142 fn get_param_id(param: &Arc<RwLock<Tensor>>) -> String {
143 format!("{:p}", Arc::as_ptr(param))
144 }
145
146 fn estimate_tensor_memory(tensor: &Tensor) -> usize {
148 tensor.shape().numel() * std::mem::size_of::<f32>()
149 }
150
151 fn can_allocate(&self, size: usize) -> bool {
153 if self.config.max_memory_bytes == 0 {
154 return true; }
156 self.memory_usage + size <= self.config.max_memory_bytes
157 }
158
159 fn update_memory_usage(&mut self, delta: isize) {
161 if delta < 0 {
162 self.memory_usage = self.memory_usage.saturating_sub((-delta) as usize);
163 } else {
164 self.memory_usage += delta as usize;
165 }
166 }
167
168 fn maybe_compress_state(&mut self, param_id: &str) -> Result<()> {
170 if !self.config.compress_state {
171 return Ok(());
172 }
173
174 if let Some(param_state) = self.state.get_mut(param_id) {
176 let state_keys: Vec<String> = param_state.keys().cloned().collect();
178 for state_name in state_keys {
179 if state_name == "exp_avg"
180 || state_name == "exp_avg_sq"
181 || state_name == "max_exp_avg_sq"
182 {
183 if let Some(state_tensor) = param_state.get(&state_name).cloned() {
184 let quantized = Self::quantize_tensor(&state_tensor)?;
187 param_state.insert(state_name, quantized);
188 }
189 }
190 }
191 }
192
193 Ok(())
194 }
195
196 fn quantize_tensor(tensor: &Tensor) -> Result<Tensor> {
198 let data = tensor.to_vec()?;
203
204 let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
206 let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
207
208 let range = (max_val - min_val).max(1e-8);
210
211 let quantized_data: Vec<f32> = data
214 .iter()
215 .map(|&val| {
216 let normalized = (val - min_val) / range;
218 let quantized = (normalized * 255.0).round().clamp(0.0, 255.0) as u8;
220 min_val + (quantized as f32 / 255.0) * range
222 })
223 .collect();
224
225 let quantized_tensor = Tensor::from_data(
227 quantized_data,
228 tensor.shape().dims().to_vec(),
229 tensor.device(),
230 )?;
231
232 Ok(quantized_tensor)
233 }
234
    /// Performs one Adam update for a single parameter, allocating state
    /// buffers lazily under the configured memory budget.
    ///
    /// # Errors
    /// Returns an error when the parameter has no gradient, when a new
    /// state buffer would exceed `config.max_memory_bytes`, or when any
    /// tensor operation fails.
    ///
    /// NOTE(review): panics if `config.checkpoint_interval == 0` (modulo by
    /// zero at the bottom) — confirm callers never configure 0.
    fn adam_step_memory_efficient(
        &mut self,
        param: &Arc<RwLock<Tensor>>,
        group_lr: f32,
    ) -> Result<()> {
        let param_id = Self::get_param_id(param);
        let mut param_write = param.write();
        let grad = param_write.grad().ok_or_else(|| {
            TorshError::invalid_argument_with_context(
                "Parameter has no gradient",
                "memory_efficient_adam_step",
            )
        })?;

        // Weight decay folded into the gradient: g' = g + wd * p.
        let effective_grad = if self.weight_decay > 0.0 {
            grad.add(&param_write.mul_scalar(self.weight_decay)?)?
        } else {
            grad.clone()
        };

        let momentum_key = "momentum".to_string();
        let velocity_key = "velocity".to_string();
        let max_exp_avg_sq_key = "max_exp_avg_sq".to_string();

        // Estimated cost of one state buffer for this parameter.
        let param_memory = Self::estimate_tensor_memory(&param_write);

        // A buffer "needs" allocation when this parameter has no state yet
        // or the specific key is missing.
        let needs_momentum = !self.state.contains_key(&param_id)
            || !self
                .state
                .get(&param_id)
                .expect("state should exist after contains_key check")
                .contains_key(&momentum_key);
        let needs_velocity = !self.state.contains_key(&param_id)
            || !self
                .state
                .get(&param_id)
                .expect("state should exist after contains_key check")
                .contains_key(&velocity_key);

        // Budget check BEFORE any allocation so failure leaves state intact.
        if needs_momentum && !self.can_allocate(param_memory) {
            return Err(TorshError::invalid_argument_with_context(
                "Insufficient memory for momentum buffer",
                "memory_efficient_adam_step",
            ));
        }

        if needs_velocity && !self.can_allocate(param_memory) {
            return Err(TorshError::invalid_argument_with_context(
                "Insufficient memory for velocity buffer",
                "memory_efficient_adam_step",
            ));
        }

        // Record the bytes for every buffer about to be created.
        let memory_to_add = (if needs_momentum { 1 } else { 0 }
            + if needs_velocity { 1 } else { 0 })
            * param_memory;
        if memory_to_add > 0 {
            self.update_memory_usage(memory_to_add as isize);
        }

        // Mutable borrow of this parameter's state map begins here.
        let param_state = self.state.entry(param_id.clone()).or_default();

        // First moment estimate: pooled or fresh zero tensor on first use.
        let momentum = if let Some(m) = param_state.get(&momentum_key) {
            m.clone()
        } else {
            let m = if self.config.use_memory_pool {
                self.memory_pool
                    .get_tensor(param_write.shape().dims(), param_write.device())?
            } else {
                Tensor::zeros(param_write.shape().dims(), param_write.device())?
            };
            param_state.insert(momentum_key.clone(), m.clone());
            m
        };

        // Second moment estimate, same lazy-allocation policy.
        let velocity = if let Some(v) = param_state.get(&velocity_key) {
            v.clone()
        } else {
            let v = if self.config.use_memory_pool {
                self.memory_pool
                    .get_tensor(param_write.shape().dims(), param_write.device())?
            } else {
                Tensor::zeros(param_write.shape().dims(), param_write.device())?
            };
            param_state.insert(velocity_key.clone(), v.clone());
            v
        };

        // m_t = beta1 * m_{t-1} + (1 - beta1) * g
        let new_momentum = momentum
            .mul_scalar(self.beta1)?
            .add(&effective_grad.mul_scalar(1.0 - self.beta1)?)?;

        // v_t = beta2 * v_{t-1} + (1 - beta2) * g^2
        let grad_squared = effective_grad.mul_op(&effective_grad)?;
        let new_velocity = velocity
            .mul_scalar(self.beta2)?
            .add(&grad_squared.mul_scalar(1.0 - self.beta2)?)?;

        // step_count was incremented in step() before this call, so the
        // corrections are well-defined (never divide by zero) on step 1.
        let bias_correction1 = 1.0 - self.beta1.powi(self.step_count as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.step_count as i32);

        let corrected_momentum = new_momentum.div_scalar(bias_correction1)?;
        let corrected_velocity = new_velocity.div_scalar(bias_correction2)?;

        let needs_max_velocity_check =
            self.amsgrad && !param_state.contains_key(&max_exp_avg_sq_key);

        if needs_max_velocity_check {
            // Drop the state borrow so `self` methods can be called below.
            let _ = param_state;

            if !self.can_allocate(param_memory) {
                return Err(TorshError::invalid_argument_with_context(
                    "Insufficient memory for max velocity buffer",
                    "memory_efficient_adam_step",
                ));
            }

            self.update_memory_usage(param_memory as isize);
        }

        // Re-borrow the state map after the budget check above.
        let param_state = self.state.entry(param_id.clone()).or_default();

        // AMSGrad: keep a running element-wise maximum of the second moment.
        // NOTE(review): the max is taken over the bias-CORRECTED velocity,
        // which differs slightly from the canonical AMSGrad formulation
        // (max over raw v_t) — confirm this is intended.
        let exp_avg_sq_hat = if self.amsgrad {
            let max_exp_avg_sq = if let Some(max_v) = param_state.get(&max_exp_avg_sq_key) {
                max_v.clone()
            } else {
                let max_v = if self.config.use_memory_pool {
                    self.memory_pool
                        .get_tensor(param_write.shape().dims(), param_write.device())?
                } else {
                    Tensor::zeros(param_write.shape().dims(), param_write.device())?
                };
                param_state.insert(max_exp_avg_sq_key.clone(), max_v.clone());
                max_v
            };

            let new_max_exp_avg_sq = max_exp_avg_sq.maximum(&corrected_velocity)?;
            param_state.insert(max_exp_avg_sq_key, new_max_exp_avg_sq.clone());
            new_max_exp_avg_sq
        } else {
            corrected_velocity.clone()
        };

        // p <- p - lr * m_hat / (sqrt(v_hat) + eps)
        let denominator = exp_avg_sq_hat.sqrt()?.add_scalar(self.eps)?;
        let update = corrected_momentum
            .div(&denominator)?
            .mul_scalar(-group_lr)?;

        *param_write = param_write.add(&update)?;

        // Persist the (uncorrected) moment estimates for the next step.
        param_state.insert(momentum_key, new_momentum);
        param_state.insert(velocity_key, new_velocity);

        // Periodic state compression (see MemoryConfig::checkpoint_interval).
        if self.step_count % self.config.checkpoint_interval == 0 {
            self.maybe_compress_state(&param_id)?;
        }

        Ok(())
    }
410}
411
412impl Optimizer for MemoryEfficientAdam {
413 fn step(&mut self) -> OptimizerResult<()> {
414 self.step_count += 1;
415
416 let param_data: Vec<(Arc<RwLock<Tensor>>, f32)> = self
418 .param_groups
419 .iter()
420 .flat_map(|group| {
421 let group_lr = group.lr;
422 group
423 .params
424 .iter()
425 .map(move |param| (param.clone(), group_lr))
426 })
427 .collect();
428
429 for (param, group_lr) in param_data {
430 self.adam_step_memory_efficient(¶m, group_lr)?;
431 }
432
433 Ok(())
434 }
435
436 fn zero_grad(&mut self) {
437 for group in &self.param_groups {
438 for param in &group.params {
439 param.write().zero_grad();
440 }
441 }
442 }
443
444 fn get_lr(&self) -> Vec<f32> {
445 self.param_groups.iter().map(|g| g.lr).collect()
446 }
447
448 fn set_lr(&mut self, lr: f32) {
449 for group in &mut self.param_groups {
450 group.lr = lr;
451 }
452 }
453
454 fn add_param_group(&mut self, params: Vec<Arc<RwLock<Tensor>>>, options: HashMap<String, f32>) {
455 let lr = options.get("lr").copied().unwrap_or(1e-3);
456 let group = ParamGroup::new(params, lr).with_options(options);
457 self.param_groups.push(group);
458 }
459
460 fn state_dict(&self) -> OptimizerResult<OptimizerState> {
461 let param_groups = self
462 .param_groups
463 .iter()
464 .map(|g| ParamGroupState {
465 lr: g.lr,
466 options: g.options.clone(),
467 param_count: g.params.len(),
468 })
469 .collect();
470
471 Ok(OptimizerState {
472 optimizer_type: "MemoryEfficientAdam".to_string(),
473 version: "0.1.0".to_string(),
474 param_groups,
475 state: self.state.clone(),
476 global_state: HashMap::new(),
477 })
478 }
479
480 fn load_state_dict(&mut self, state: OptimizerState) -> OptimizerResult<()> {
481 if state.param_groups.len() != self.param_groups.len() {
482 return Err(OptimizerError::InvalidParameter(
483 "Parameter group count mismatch".to_string(),
484 ));
485 }
486
487 for (i, group_state) in state.param_groups.iter().enumerate() {
488 self.param_groups[i].lr = group_state.lr;
489 self.param_groups[i].options = group_state.options.clone();
490 }
491
492 self.state = state.state;
493 Ok(())
494 }
495}
496
/// L-BFGS-flavored optimizer scaffold with memory-conscious bookkeeping.
///
/// NOTE(review): `max_iter`, `tolerance_grad`, `tolerance_change`,
/// `memory_pool`, `config`, and `history_buffer` are stored but never read
/// by the current `step()` (which does a plain gradient step) — presumably
/// reserved for a full L-BFGS implementation; confirm.
pub struct MemoryEfficientLBFGS {
    /// Parameter groups, each with its own learning rate.
    param_groups: Vec<ParamGroup>,
    /// Per-parameter state keyed by parameter identity.
    state: HashMap<String, HashMap<String, Tensor>>,
    /// Number of `step()` calls so far.
    step_count: usize,

    // L-BFGS hyperparameters (currently unused by step()).
    max_iter: usize,
    tolerance_grad: f32,
    tolerance_change: f32,
    history_size: usize,

    /// Pool for recycling tensor allocations.
    memory_pool: MemoryPool,
    /// Memory-related behavior switches.
    config: MemoryConfig,
    /// Bounded history of (Tensor, Tensor, f32) triples — presumably
    /// (s, y, rho) curvature pairs; unused here, confirm semantics.
    history_buffer: CircularBuffer<(Tensor, Tensor, f32)>,
}
514
/// Fixed-capacity ring buffer that overwrites the oldest element once full.
pub struct CircularBuffer<T> {
    /// Backing storage; holds at most `capacity` elements.
    data: Vec<T>,
    /// Maximum number of retained elements.
    capacity: usize,
    /// Physical index of the logically first (oldest) element.
    start: usize,
    /// Current number of stored elements (always <= capacity).
    len: usize,
}

impl<T> CircularBuffer<T> {
    /// Creates an empty buffer retaining at most `capacity` elements.
    pub fn new(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
            capacity,
            start: 0,
            len: 0,
        }
    }

    /// Appends `item`, evicting the oldest element if the buffer is full.
    ///
    /// BUGFIX: a zero-capacity buffer previously reached `% self.capacity`
    /// with `capacity == 0` and panicked with a divide-by-zero; such a
    /// buffer now simply drops every pushed item.
    pub fn push(&mut self, item: T) {
        if self.capacity == 0 {
            return;
        }
        if self.len < self.capacity {
            self.data.push(item);
            self.len += 1;
        } else {
            // Full: (start + len) % capacity == start, i.e. overwrite the
            // oldest slot and advance the logical start.
            let index = (self.start + self.len) % self.capacity;
            self.data[index] = item;
            self.start = (self.start + 1) % self.capacity;
        }
    }

    /// Returns the element at logical position `index` (0 = oldest).
    pub fn get(&self, index: usize) -> Option<&T> {
        if index < self.len {
            let actual_index = (self.start + index) % self.capacity;
            self.data.get(actual_index)
        } else {
            None
        }
    }

    /// Number of elements currently stored.
    pub fn len(&self) -> usize {
        self.len
    }

    /// True when no elements are stored.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Removes all elements and resets the logical start.
    pub fn clear(&mut self) {
        self.data.clear();
        self.start = 0;
        self.len = 0;
    }
}
567
568impl MemoryEfficientLBFGS {
569 pub fn new(
570 params: Vec<Arc<RwLock<Tensor>>>,
571 lr: Option<f32>,
572 max_iter: Option<usize>,
573 tolerance_grad: Option<f32>,
574 tolerance_change: Option<f32>,
575 history_size: Option<usize>,
576 memory_config: Option<MemoryConfig>,
577 ) -> Self {
578 let lr = lr.unwrap_or(1.0);
579 let history_size = history_size.unwrap_or(10); let param_group = ParamGroup::new(params, lr);
582
583 Self {
584 param_groups: vec![param_group],
585 state: HashMap::new(),
586 step_count: 0,
587 max_iter: max_iter.unwrap_or(20),
588 tolerance_grad: tolerance_grad.unwrap_or(1e-7),
589 tolerance_change: tolerance_change.unwrap_or(1e-9),
590 history_size,
591 memory_pool: MemoryPool::new(),
592 config: memory_config.unwrap_or_default(),
593 history_buffer: CircularBuffer::new(history_size),
594 }
595 }
596
597 pub fn memory_stats(&self) -> HashMap<String, usize> {
599 let mut stats = HashMap::new();
600 stats.insert(
601 "total_usage".to_string(),
602 self.memory_pool.tensors.len() * std::mem::size_of::<Tensor>(),
603 );
604 stats.insert("history_size".to_string(), self.history_buffer.len());
605 stats.insert("pooled_tensors".to_string(), self.memory_pool.tensors.len());
606 stats
607 }
608
609 pub fn clear_memory(&mut self) {
611 self.memory_pool.clear();
612 self.history_buffer.clear();
613 self.state.clear();
614 }
615}
616
617impl Optimizer for MemoryEfficientLBFGS {
618 fn step(&mut self) -> OptimizerResult<()> {
619 self.step_count += 1;
622
623 for group in &self.param_groups {
625 for param in &group.params {
626 let mut param_write = param.write();
627 if let Some(grad) = param_write.grad() {
628 let update = grad.mul_scalar(-group.lr)?;
629 *param_write = param_write.add(&update)?;
630 }
631 }
632 }
633
634 Ok(())
635 }
636
637 fn zero_grad(&mut self) {
638 for group in &self.param_groups {
639 for param in &group.params {
640 param.write().zero_grad();
641 }
642 }
643 }
644
645 fn get_lr(&self) -> Vec<f32> {
646 self.param_groups.iter().map(|g| g.lr).collect()
647 }
648
649 fn set_lr(&mut self, lr: f32) {
650 for group in &mut self.param_groups {
651 group.lr = lr;
652 }
653 }
654
655 fn add_param_group(&mut self, params: Vec<Arc<RwLock<Tensor>>>, options: HashMap<String, f32>) {
656 let lr = options.get("lr").copied().unwrap_or(1.0);
657 let group = ParamGroup::new(params, lr).with_options(options);
658 self.param_groups.push(group);
659 }
660
661 fn state_dict(&self) -> OptimizerResult<OptimizerState> {
662 let param_groups = self
663 .param_groups
664 .iter()
665 .map(|g| ParamGroupState {
666 lr: g.lr,
667 options: g.options.clone(),
668 param_count: g.params.len(),
669 })
670 .collect();
671
672 Ok(OptimizerState {
673 optimizer_type: "MemoryEfficientAdam".to_string(),
674 version: "0.1.0".to_string(),
675 param_groups,
676 state: self.state.clone(),
677 global_state: HashMap::new(),
678 })
679 }
680
681 fn load_state_dict(&mut self, state: OptimizerState) -> OptimizerResult<()> {
682 if state.param_groups.len() != self.param_groups.len() {
683 return Err(OptimizerError::InvalidParameter(
684 "Parameter group count mismatch".to_string(),
685 ));
686 }
687
688 for (i, group_state) in state.param_groups.iter().enumerate() {
689 self.param_groups[i].lr = group_state.lr;
690 self.param_groups[i].options = group_state.options.clone();
691 }
692
693 self.state = state.state;
694 Ok(())
695 }
696}
697
698pub struct MemoryEfficientOptimizerBuilder {
700 memory_config: MemoryConfig,
701}
702
703impl MemoryEfficientOptimizerBuilder {
704 pub fn new() -> Self {
705 Self {
706 memory_config: MemoryConfig::default(),
707 }
708 }
709
710 pub fn max_memory_gb(mut self, gb: f32) -> Self {
711 self.memory_config.max_memory_bytes = (gb * 1_000_000_000.0) as usize;
712 self
713 }
714
715 pub fn use_memory_pool(mut self, use_pool: bool) -> Self {
716 self.memory_config.use_memory_pool = use_pool;
717 self
718 }
719
720 pub fn compress_state(mut self, compress: bool) -> Self {
721 self.memory_config.compress_state = compress;
722 self
723 }
724
725 pub fn lazy_gradients(mut self, lazy: bool) -> Self {
726 self.memory_config.lazy_gradients = lazy;
727 self
728 }
729
730 pub fn checkpoint_interval(mut self, interval: usize) -> Self {
731 self.memory_config.checkpoint_interval = interval;
732 self
733 }
734
735 pub fn build_adam(self, params: Vec<Arc<RwLock<Tensor>>>, lr: f32) -> MemoryEfficientAdam {
736 MemoryEfficientAdam::new(
737 params,
738 lr,
739 None,
740 None,
741 None,
742 None,
743 None,
744 Some(self.memory_config),
745 )
746 }
747
748 pub fn build_lbfgs(self, params: Vec<Arc<RwLock<Tensor>>>, lr: f32) -> MemoryEfficientLBFGS {
749 MemoryEfficientLBFGS::new(
750 params,
751 Some(lr),
752 None,
753 None,
754 None,
755 None,
756 Some(self.memory_config),
757 )
758 }
759}
760
761impl Default for MemoryEfficientOptimizerBuilder {
762 fn default() -> Self {
763 Self::new()
764 }
765}
766
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::randn;

    /// Pool hands out a correctly shaped tensor and reuses a returned one.
    #[test]
    fn test_memory_pool() -> OptimizerResult<()> {
        let mut pool = MemoryPool::new();
        let tensor = pool.get_tensor(&[2, 2], torsh_core::device::DeviceType::Cpu)?;
        assert_eq!(tensor.shape().dims(), &[2, 2]);

        pool.return_tensor(tensor);
        let reused = pool.get_tensor(&[2, 2], torsh_core::device::DeviceType::Cpu)?;
        assert_eq!(reused.shape().dims(), &[2, 2]);
        Ok(())
    }

    /// Pushing past capacity evicts the oldest element (FIFO order kept).
    #[test]
    fn test_circular_buffer() {
        let mut buffer = CircularBuffer::new(3);
        buffer.push(1);
        buffer.push(2);
        buffer.push(3);
        buffer.push(4);
        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.get(0), Some(&2));
        assert_eq!(buffer.get(1), Some(&3));
        assert_eq!(buffer.get(2), Some(&4));
    }

    /// Constructor applies the requested learning rate.
    #[test]
    fn test_memory_efficient_adam_creation() -> OptimizerResult<()> {
        let params = vec![Arc::new(RwLock::new(randn::<f32>(&[2, 2])?))];
        let optimizer = MemoryEfficientAdam::new(params, 0.001, None, None, None, None, None, None);
        assert_eq!(optimizer.get_lr()[0], 0.001);
        Ok(())
    }

    /// Constructor applies both learning rate and history size.
    #[test]
    fn test_memory_efficient_lbfgs_creation() -> OptimizerResult<()> {
        let params = vec![Arc::new(RwLock::new(randn::<f32>(&[2, 2])?))];
        let optimizer =
            MemoryEfficientLBFGS::new(params, Some(0.1), None, None, None, Some(5), None);
        assert_eq!(optimizer.get_lr()[0], 0.1);
        assert_eq!(optimizer.history_size, 5);
        Ok(())
    }

    /// Builder options propagate into the built optimizer's config.
    #[test]
    fn test_builder_pattern() -> OptimizerResult<()> {
        let params = vec![Arc::new(RwLock::new(randn::<f32>(&[2, 2])?))];
        let optimizer = MemoryEfficientOptimizerBuilder::new()
            .max_memory_gb(1.0)
            .use_memory_pool(true)
            .compress_state(true)
            .build_adam(params, 0.001);

        assert_eq!(optimizer.get_lr()[0], 0.001);
        assert_eq!(optimizer.config.max_memory_bytes, 1_000_000_000);
        assert!(optimizer.config.use_memory_pool);
        assert!(optimizer.config.compress_state);
        Ok(())
    }
}