// torsh_optim/memory_efficient.rs
//! Memory-efficient optimizer implementations
//!
//! This module provides optimizers designed to minimize memory usage during training,
//! particularly useful for large models or resource-constrained environments.
use std::collections::HashMap;
use std::ops::Add;
use std::sync::Arc;

use parking_lot::RwLock;

use crate::{
    Optimizer, OptimizerError, OptimizerResult, OptimizerState, ParamGroup, ParamGroupState,
};
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
15
16/// Memory pool for reusing tensor allocations
17pub struct MemoryPool {
18    tensors: Vec<Tensor>,
19    shapes_cache: HashMap<Vec<usize>, Vec<usize>>, // shape -> indices in tensors vec
20}
21
22impl MemoryPool {
23    pub fn new() -> Self {
24        Self {
25            tensors: Vec::new(),
26            shapes_cache: HashMap::new(),
27        }
28    }
29
30    /// Get a tensor from the pool or allocate a new one
31    pub fn get_tensor(
32        &mut self,
33        shape: &[usize],
34        device: torsh_core::device::DeviceType,
35    ) -> torsh_core::error::Result<Tensor> {
36        let shape_vec = shape.to_vec();
37
38        if let Some(indices) = self.shapes_cache.get_mut(&shape_vec) {
39            if let Some(idx) = indices.pop() {
40                // Reuse existing tensor
41                let mut tensor = self.tensors.swap_remove(idx);
42                let _ = tensor.zero_();
43                return Ok(tensor);
44            }
45        }
46
47        // Allocate new tensor
48        Ok(Tensor::zeros(shape, device)?)
49    }
50
51    /// Return a tensor to the pool
52    pub fn return_tensor(&mut self, tensor: Tensor) {
53        let shape = tensor.shape().dims().to_vec();
54        let idx = self.tensors.len();
55        self.tensors.push(tensor);
56
57        self.shapes_cache.entry(shape).or_default().push(idx);
58    }
59
60    /// Clear the pool to free memory
61    pub fn clear(&mut self) {
62        self.tensors.clear();
63        self.shapes_cache.clear();
64    }
65}
66
/// Configuration for memory-efficient optimizers.
#[derive(Clone, Debug)] // Debug added: public config types should be printable for diagnostics
pub struct MemoryConfig {
    /// Maximum memory usage in bytes (0 = unlimited)
    pub max_memory_bytes: usize,
    /// Use memory pooling
    pub use_memory_pool: bool,
    /// Enable state compression
    pub compress_state: bool,
    /// Use lazy gradient accumulation
    pub lazy_gradients: bool,
    /// Gradient checkpointing interval (steps between compression passes)
    pub checkpoint_interval: usize,
}

impl Default for MemoryConfig {
    fn default() -> Self {
        Self {
            max_memory_bytes: 0, // Unlimited
            use_memory_pool: true,
            compress_state: false,
            lazy_gradients: true,
            checkpoint_interval: 100,
        }
    }
}
93
/// Memory-efficient Adam optimizer with reduced memory footprint.
///
/// State buffers (first/second moment estimates) can be allocated through a
/// `MemoryPool` and periodically re-quantized when
/// `MemoryConfig::compress_state` is enabled.
pub struct MemoryEfficientAdam {
    // Parameter groups under optimization, each with its own learning rate.
    param_groups: Vec<ParamGroup>,
    // Per-parameter state: param id (pointer string) -> state name -> tensor.
    state: HashMap<String, HashMap<String, Tensor>>,
    // Number of step() calls so far; drives bias correction.
    step_count: usize,

    // Adam parameters
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,

    // Memory optimization
    memory_pool: MemoryPool,
    config: MemoryConfig,
    // Running estimate of state-buffer bytes allocated so far.
    memory_usage: usize,
}
112
113impl MemoryEfficientAdam {
114    #[allow(clippy::too_many_arguments)]
115    pub fn new(
116        params: Vec<Arc<RwLock<Tensor>>>,
117        lr: f32,
118        beta1: Option<f32>,
119        beta2: Option<f32>,
120        eps: Option<f32>,
121        weight_decay: Option<f32>,
122        amsgrad: Option<bool>,
123        memory_config: Option<MemoryConfig>,
124    ) -> Self {
125        let param_group = ParamGroup::new(params, lr);
126
127        Self {
128            param_groups: vec![param_group],
129            state: HashMap::new(),
130            step_count: 0,
131            beta1: beta1.unwrap_or(0.9),
132            beta2: beta2.unwrap_or(0.999),
133            eps: eps.unwrap_or(1e-8),
134            weight_decay: weight_decay.unwrap_or(0.0),
135            amsgrad: amsgrad.unwrap_or(false),
136            memory_pool: MemoryPool::new(),
137            config: memory_config.unwrap_or_default(),
138            memory_usage: 0,
139        }
140    }
141
142    fn get_param_id(param: &Arc<RwLock<Tensor>>) -> String {
143        format!("{:p}", Arc::as_ptr(param))
144    }
145
146    /// Estimate memory usage for a tensor
147    fn estimate_tensor_memory(tensor: &Tensor) -> usize {
148        tensor.shape().numel() * std::mem::size_of::<f32>()
149    }
150
151    /// Check if we can allocate more memory
152    fn can_allocate(&self, size: usize) -> bool {
153        if self.config.max_memory_bytes == 0 {
154            return true; // Unlimited
155        }
156        self.memory_usage + size <= self.config.max_memory_bytes
157    }
158
159    /// Update memory usage tracking
160    fn update_memory_usage(&mut self, delta: isize) {
161        if delta < 0 {
162            self.memory_usage = self.memory_usage.saturating_sub((-delta) as usize);
163        } else {
164            self.memory_usage += delta as usize;
165        }
166    }
167
168    /// Compress state if needed
169    fn maybe_compress_state(&mut self, param_id: &str) -> Result<()> {
170        if !self.config.compress_state {
171            return Ok(());
172        }
173
174        // Get the state for this parameter
175        if let Some(param_state) = self.state.get_mut(param_id) {
176            // Compress momentum and squared gradient states using quantization
177            let state_keys: Vec<String> = param_state.keys().cloned().collect();
178            for state_name in state_keys {
179                if state_name == "exp_avg"
180                    || state_name == "exp_avg_sq"
181                    || state_name == "max_exp_avg_sq"
182                {
183                    if let Some(state_tensor) = param_state.get(&state_name).cloned() {
184                        // Apply quantization to reduce memory usage
185                        // Convert to lower precision and back to maintain approximate values
186                        let quantized = Self::quantize_tensor(&state_tensor)?;
187                        param_state.insert(state_name, quantized);
188                    }
189                }
190            }
191        }
192
193        Ok(())
194    }
195
196    /// Quantize tensor to reduce memory usage
197    fn quantize_tensor(tensor: &Tensor) -> Result<Tensor> {
198        // Simple quantization scheme: convert to f16 precision and back to f32
199        // This reduces memory usage by approximately 50% for state tensors
200
201        // Get tensor data as f32 values
202        let data = tensor.to_vec()?;
203
204        // Find min and max values for quantization scaling
205        let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
206        let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
207
208        // Avoid division by zero
209        let range = (max_val - min_val).max(1e-8);
210
211        // Quantize to 8-bit integers and dequantize back to f32
212        // This provides significant compression with acceptable precision loss
213        let quantized_data: Vec<f32> = data
214            .iter()
215            .map(|&val| {
216                // Normalize to [0, 1]
217                let normalized = (val - min_val) / range;
218                // Quantize to 8-bit (0-255)
219                let quantized = (normalized * 255.0).round().clamp(0.0, 255.0) as u8;
220                // Dequantize back to f32
221                min_val + (quantized as f32 / 255.0) * range
222            })
223            .collect();
224
225        // Create new tensor with quantized data
226        let quantized_tensor = Tensor::from_data(
227            quantized_data,
228            tensor.shape().dims().to_vec(),
229            tensor.device(),
230        )?;
231
232        Ok(quantized_tensor)
233    }
234
235    /// Perform memory-efficient Adam step for a single parameter
236    fn adam_step_memory_efficient(
237        &mut self,
238        param: &Arc<RwLock<Tensor>>,
239        group_lr: f32,
240    ) -> Result<()> {
241        let param_id = Self::get_param_id(param);
242        let mut param_write = param.write();
243        let grad = param_write.grad().ok_or_else(|| {
244            TorshError::invalid_argument_with_context(
245                "Parameter has no gradient",
246                "memory_efficient_adam_step",
247            )
248        })?;
249
250        // Apply weight decay if specified
251        let effective_grad = if self.weight_decay > 0.0 {
252            grad.add(&param_write.mul_scalar(self.weight_decay)?)?
253        } else {
254            grad.clone()
255        };
256
257        // Get or initialize momentum buffers
258        let momentum_key = "momentum".to_string();
259        let velocity_key = "velocity".to_string();
260        let max_exp_avg_sq_key = "max_exp_avg_sq".to_string();
261
262        // Check memory constraints before allocation
263        let param_memory = Self::estimate_tensor_memory(&param_write);
264
265        // Check if we need to allocate memory first
266        let needs_momentum = !self.state.contains_key(&param_id)
267            || !self
268                .state
269                .get(&param_id)
270                .expect("state should exist after contains_key check")
271                .contains_key(&momentum_key);
272        let needs_velocity = !self.state.contains_key(&param_id)
273            || !self
274                .state
275                .get(&param_id)
276                .expect("state should exist after contains_key check")
277                .contains_key(&velocity_key);
278
279        if needs_momentum && !self.can_allocate(param_memory) {
280            return Err(TorshError::invalid_argument_with_context(
281                "Insufficient memory for momentum buffer",
282                "memory_efficient_adam_step",
283            ));
284        }
285
286        if needs_velocity && !self.can_allocate(param_memory) {
287            return Err(TorshError::invalid_argument_with_context(
288                "Insufficient memory for velocity buffer",
289                "memory_efficient_adam_step",
290            ));
291        }
292
293        // Update memory usage for new allocations
294        let memory_to_add = (if needs_momentum { 1 } else { 0 }
295            + if needs_velocity { 1 } else { 0 })
296            * param_memory;
297        if memory_to_add > 0 {
298            self.update_memory_usage(memory_to_add as isize);
299        }
300
301        // Now safely get the parameter state
302        let param_state = self.state.entry(param_id.clone()).or_default();
303
304        let momentum = if let Some(m) = param_state.get(&momentum_key) {
305            m.clone()
306        } else {
307            let m = if self.config.use_memory_pool {
308                self.memory_pool
309                    .get_tensor(param_write.shape().dims(), param_write.device())?
310            } else {
311                Tensor::zeros(param_write.shape().dims(), param_write.device())?
312            };
313            param_state.insert(momentum_key.clone(), m.clone());
314            m
315        };
316
317        let velocity = if let Some(v) = param_state.get(&velocity_key) {
318            v.clone()
319        } else {
320            let v = if self.config.use_memory_pool {
321                self.memory_pool
322                    .get_tensor(param_write.shape().dims(), param_write.device())?
323            } else {
324                Tensor::zeros(param_write.shape().dims(), param_write.device())?
325            };
326            param_state.insert(velocity_key.clone(), v.clone());
327            v
328        };
329
330        // Update biased first moment estimate
331        let new_momentum = momentum
332            .mul_scalar(self.beta1)?
333            .add(&effective_grad.mul_scalar(1.0 - self.beta1)?)?;
334
335        // Update biased second raw moment estimate
336        let grad_squared = effective_grad.mul_op(&effective_grad)?;
337        let new_velocity = velocity
338            .mul_scalar(self.beta2)?
339            .add(&grad_squared.mul_scalar(1.0 - self.beta2)?)?;
340
341        // Bias correction
342        let bias_correction1 = 1.0 - self.beta1.powi(self.step_count as i32);
343        let bias_correction2 = 1.0 - self.beta2.powi(self.step_count as i32);
344
345        let corrected_momentum = new_momentum.div_scalar(bias_correction1)?;
346        let corrected_velocity = new_velocity.div_scalar(bias_correction2)?;
347
348        // Handle AMSGrad variant - check memory first before borrowing param_state again
349        let needs_max_velocity_check =
350            self.amsgrad && !param_state.contains_key(&max_exp_avg_sq_key);
351
352        // Drop param_state temporarily to check memory if needed
353        if needs_max_velocity_check {
354            let _ = param_state;
355
356            if !self.can_allocate(param_memory) {
357                return Err(TorshError::invalid_argument_with_context(
358                    "Insufficient memory for max velocity buffer",
359                    "memory_efficient_adam_step",
360                ));
361            }
362
363            self.update_memory_usage(param_memory as isize);
364        }
365
366        // Re-acquire param_state for the rest of the function
367        let param_state = self.state.entry(param_id.clone()).or_default();
368
369        let exp_avg_sq_hat = if self.amsgrad {
370            let max_exp_avg_sq = if let Some(max_v) = param_state.get(&max_exp_avg_sq_key) {
371                max_v.clone()
372            } else {
373                let max_v = if self.config.use_memory_pool {
374                    self.memory_pool
375                        .get_tensor(param_write.shape().dims(), param_write.device())?
376                } else {
377                    Tensor::zeros(param_write.shape().dims(), param_write.device())?
378                };
379                param_state.insert(max_exp_avg_sq_key.clone(), max_v.clone());
380                max_v
381            };
382
383            let new_max_exp_avg_sq = max_exp_avg_sq.maximum(&corrected_velocity)?;
384            param_state.insert(max_exp_avg_sq_key, new_max_exp_avg_sq.clone());
385            new_max_exp_avg_sq
386        } else {
387            corrected_velocity.clone()
388        };
389
390        // Compute update
391        let denominator = exp_avg_sq_hat.sqrt()?.add_scalar(self.eps)?;
392        let update = corrected_momentum
393            .div(&denominator)?
394            .mul_scalar(-group_lr)?;
395
396        // Update parameters
397        *param_write = param_write.add(&update)?;
398
399        // Update state
400        param_state.insert(momentum_key, new_momentum);
401        param_state.insert(velocity_key, new_velocity);
402
403        // Compress state periodically
404        if self.step_count % self.config.checkpoint_interval == 0 {
405            self.maybe_compress_state(&param_id)?;
406        }
407
408        Ok(())
409    }
410}
411
412impl Optimizer for MemoryEfficientAdam {
413    fn step(&mut self) -> OptimizerResult<()> {
414        self.step_count += 1;
415
416        // Collect parameters and learning rates first to avoid borrowing issues
417        let param_data: Vec<(Arc<RwLock<Tensor>>, f32)> = self
418            .param_groups
419            .iter()
420            .flat_map(|group| {
421                let group_lr = group.lr;
422                group
423                    .params
424                    .iter()
425                    .map(move |param| (param.clone(), group_lr))
426            })
427            .collect();
428
429        for (param, group_lr) in param_data {
430            self.adam_step_memory_efficient(&param, group_lr)?;
431        }
432
433        Ok(())
434    }
435
436    fn zero_grad(&mut self) {
437        for group in &self.param_groups {
438            for param in &group.params {
439                param.write().zero_grad();
440            }
441        }
442    }
443
444    fn get_lr(&self) -> Vec<f32> {
445        self.param_groups.iter().map(|g| g.lr).collect()
446    }
447
448    fn set_lr(&mut self, lr: f32) {
449        for group in &mut self.param_groups {
450            group.lr = lr;
451        }
452    }
453
454    fn add_param_group(&mut self, params: Vec<Arc<RwLock<Tensor>>>, options: HashMap<String, f32>) {
455        let lr = options.get("lr").copied().unwrap_or(1e-3);
456        let group = ParamGroup::new(params, lr).with_options(options);
457        self.param_groups.push(group);
458    }
459
460    fn state_dict(&self) -> OptimizerResult<OptimizerState> {
461        let param_groups = self
462            .param_groups
463            .iter()
464            .map(|g| ParamGroupState {
465                lr: g.lr,
466                options: g.options.clone(),
467                param_count: g.params.len(),
468            })
469            .collect();
470
471        Ok(OptimizerState {
472            optimizer_type: "MemoryEfficientAdam".to_string(),
473            version: "0.1.0".to_string(),
474            param_groups,
475            state: self.state.clone(),
476            global_state: HashMap::new(),
477        })
478    }
479
480    fn load_state_dict(&mut self, state: OptimizerState) -> OptimizerResult<()> {
481        if state.param_groups.len() != self.param_groups.len() {
482            return Err(OptimizerError::InvalidParameter(
483                "Parameter group count mismatch".to_string(),
484            ));
485        }
486
487        for (i, group_state) in state.param_groups.iter().enumerate() {
488            self.param_groups[i].lr = group_state.lr;
489            self.param_groups[i].options = group_state.options.clone();
490        }
491
492        self.state = state.state;
493        Ok(())
494    }
495}
496
/// Memory-efficient L-BFGS with improved history management.
pub struct MemoryEfficientLBFGS {
    // Parameter groups under optimization.
    param_groups: Vec<ParamGroup>,
    // Per-parameter state (not yet consulted by the placeholder step()).
    state: HashMap<String, HashMap<String, Tensor>>,
    // Number of step() calls so far.
    step_count: usize,

    // L-BFGS parameters
    // NOTE(review): max_iter / tolerance_grad / tolerance_change are stored
    // but unused by the current simplified step() — confirm they are wired
    // in once the full L-BFGS loop is implemented.
    max_iter: usize,
    tolerance_grad: f32,
    tolerance_change: f32,
    history_size: usize,

    // Memory optimization
    memory_pool: MemoryPool,
    config: MemoryConfig,
    // Bounded curvature-pair history for the two-loop recursion: (s, y, rho).
    history_buffer: CircularBuffer<(Tensor, Tensor, f32)>,
}
514
/// Circular buffer for L-BFGS history with memory management.
///
/// Fixed-capacity ring buffer: once full, pushing evicts the oldest item.
/// `get(0)` always returns the oldest retained item.
pub struct CircularBuffer<T> {
    // Backing storage; at most `capacity` elements.
    data: Vec<T>,
    // Maximum number of retained items.
    capacity: usize,
    // Physical index of the logically-oldest item.
    start: usize,
    // Current number of items.
    len: usize,
}

impl<T> CircularBuffer<T> {
    /// Create an empty buffer retaining at most `capacity` items.
    pub fn new(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
            capacity,
            start: 0,
            len: 0,
        }
    }

    /// Append `item`, evicting the oldest entry when the buffer is full.
    ///
    /// A zero-capacity buffer discards every item. (Previously a push into a
    /// zero-capacity buffer panicked with a modulo-by-zero.)
    pub fn push(&mut self, item: T) {
        if self.capacity == 0 {
            return;
        }
        if self.len < self.capacity {
            self.data.push(item);
            self.len += 1;
        } else {
            // Full: (start + len) % capacity == start, i.e. the oldest slot.
            // Overwrite it and advance the logical start.
            let index = (self.start + self.len) % self.capacity;
            self.data[index] = item;
            self.start = (self.start + 1) % self.capacity;
        }
    }

    /// Item at logical position `index` (0 = oldest), or `None` if out of range.
    pub fn get(&self, index: usize) -> Option<&T> {
        if index < self.len {
            let actual_index = (self.start + index) % self.capacity;
            self.data.get(actual_index)
        } else {
            None
        }
    }

    /// Number of items currently stored.
    pub fn len(&self) -> usize {
        self.len
    }

    /// True when no items are stored.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Drop all items and reset the buffer.
    pub fn clear(&mut self) {
        self.data.clear();
        self.start = 0;
        self.len = 0;
    }
}
567
568impl MemoryEfficientLBFGS {
569    pub fn new(
570        params: Vec<Arc<RwLock<Tensor>>>,
571        lr: Option<f32>,
572        max_iter: Option<usize>,
573        tolerance_grad: Option<f32>,
574        tolerance_change: Option<f32>,
575        history_size: Option<usize>,
576        memory_config: Option<MemoryConfig>,
577    ) -> Self {
578        let lr = lr.unwrap_or(1.0);
579        let history_size = history_size.unwrap_or(10); // Reduced default for memory efficiency
580
581        let param_group = ParamGroup::new(params, lr);
582
583        Self {
584            param_groups: vec![param_group],
585            state: HashMap::new(),
586            step_count: 0,
587            max_iter: max_iter.unwrap_or(20),
588            tolerance_grad: tolerance_grad.unwrap_or(1e-7),
589            tolerance_change: tolerance_change.unwrap_or(1e-9),
590            history_size,
591            memory_pool: MemoryPool::new(),
592            config: memory_config.unwrap_or_default(),
593            history_buffer: CircularBuffer::new(history_size),
594        }
595    }
596
597    /// Get memory usage statistics
598    pub fn memory_stats(&self) -> HashMap<String, usize> {
599        let mut stats = HashMap::new();
600        stats.insert(
601            "total_usage".to_string(),
602            self.memory_pool.tensors.len() * std::mem::size_of::<Tensor>(),
603        );
604        stats.insert("history_size".to_string(), self.history_buffer.len());
605        stats.insert("pooled_tensors".to_string(), self.memory_pool.tensors.len());
606        stats
607    }
608
609    /// Clear memory pools and history to free up memory
610    pub fn clear_memory(&mut self) {
611        self.memory_pool.clear();
612        self.history_buffer.clear();
613        self.state.clear();
614    }
615}
616
617impl Optimizer for MemoryEfficientLBFGS {
618    fn step(&mut self) -> OptimizerResult<()> {
619        // Simplified memory-efficient L-BFGS implementation
620        // This is a placeholder that would need full implementation
621        self.step_count += 1;
622
623        // For now, just implement a simple gradient descent step
624        for group in &self.param_groups {
625            for param in &group.params {
626                let mut param_write = param.write();
627                if let Some(grad) = param_write.grad() {
628                    let update = grad.mul_scalar(-group.lr)?;
629                    *param_write = param_write.add(&update)?;
630                }
631            }
632        }
633
634        Ok(())
635    }
636
637    fn zero_grad(&mut self) {
638        for group in &self.param_groups {
639            for param in &group.params {
640                param.write().zero_grad();
641            }
642        }
643    }
644
645    fn get_lr(&self) -> Vec<f32> {
646        self.param_groups.iter().map(|g| g.lr).collect()
647    }
648
649    fn set_lr(&mut self, lr: f32) {
650        for group in &mut self.param_groups {
651            group.lr = lr;
652        }
653    }
654
655    fn add_param_group(&mut self, params: Vec<Arc<RwLock<Tensor>>>, options: HashMap<String, f32>) {
656        let lr = options.get("lr").copied().unwrap_or(1.0);
657        let group = ParamGroup::new(params, lr).with_options(options);
658        self.param_groups.push(group);
659    }
660
661    fn state_dict(&self) -> OptimizerResult<OptimizerState> {
662        let param_groups = self
663            .param_groups
664            .iter()
665            .map(|g| ParamGroupState {
666                lr: g.lr,
667                options: g.options.clone(),
668                param_count: g.params.len(),
669            })
670            .collect();
671
672        Ok(OptimizerState {
673            optimizer_type: "MemoryEfficientAdam".to_string(),
674            version: "0.1.0".to_string(),
675            param_groups,
676            state: self.state.clone(),
677            global_state: HashMap::new(),
678        })
679    }
680
681    fn load_state_dict(&mut self, state: OptimizerState) -> OptimizerResult<()> {
682        if state.param_groups.len() != self.param_groups.len() {
683            return Err(OptimizerError::InvalidParameter(
684                "Parameter group count mismatch".to_string(),
685            ));
686        }
687
688        for (i, group_state) in state.param_groups.iter().enumerate() {
689            self.param_groups[i].lr = group_state.lr;
690            self.param_groups[i].options = group_state.options.clone();
691        }
692
693        self.state = state.state;
694        Ok(())
695    }
696}
697
/// Builder for memory-efficient optimizers.
///
/// Accumulates a `MemoryConfig` through chained setters, then constructs an
/// optimizer via `build_adam` / `build_lbfgs`.
pub struct MemoryEfficientOptimizerBuilder {
    // Configuration handed to the optimizer on build.
    memory_config: MemoryConfig,
}
702
impl MemoryEfficientOptimizerBuilder {
    /// Start from the default `MemoryConfig`.
    pub fn new() -> Self {
        Self {
            memory_config: MemoryConfig::default(),
        }
    }

    /// Cap optimizer state memory at `gb` decimal gigabytes (1 GB = 10^9 bytes).
    pub fn max_memory_gb(mut self, gb: f32) -> Self {
        self.memory_config.max_memory_bytes = (gb * 1_000_000_000.0) as usize;
        self
    }

    /// Enable or disable tensor reuse through the memory pool.
    pub fn use_memory_pool(mut self, use_pool: bool) -> Self {
        self.memory_config.use_memory_pool = use_pool;
        self
    }

    /// Enable or disable periodic state compression.
    pub fn compress_state(mut self, compress: bool) -> Self {
        self.memory_config.compress_state = compress;
        self
    }

    /// Enable or disable lazy gradient accumulation.
    pub fn lazy_gradients(mut self, lazy: bool) -> Self {
        self.memory_config.lazy_gradients = lazy;
        self
    }

    /// Set the number of steps between state-compression passes.
    pub fn checkpoint_interval(mut self, interval: usize) -> Self {
        self.memory_config.checkpoint_interval = interval;
        self
    }

    /// Build a `MemoryEfficientAdam` with default Adam hyperparameters.
    pub fn build_adam(self, params: Vec<Arc<RwLock<Tensor>>>, lr: f32) -> MemoryEfficientAdam {
        MemoryEfficientAdam::new(
            params,
            lr,
            None,
            None,
            None,
            None,
            None,
            Some(self.memory_config),
        )
    }

    /// Build a `MemoryEfficientLBFGS` with default L-BFGS hyperparameters.
    pub fn build_lbfgs(self, params: Vec<Arc<RwLock<Tensor>>>, lr: f32) -> MemoryEfficientLBFGS {
        MemoryEfficientLBFGS::new(
            params,
            Some(lr),
            None,
            None,
            None,
            None,
            Some(self.memory_config),
        )
    }
}
760
761impl Default for MemoryEfficientOptimizerBuilder {
762    fn default() -> Self {
763        Self::new()
764    }
765}
766
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::randn;

    // The pool hands back a correctly shaped tensor, and a returned tensor
    // can be retrieved again for the same shape.
    #[test]
    fn test_memory_pool() -> OptimizerResult<()> {
        let mut pool = MemoryPool::new();
        let tensor = pool.get_tensor(&[2, 2], torsh_core::device::DeviceType::Cpu)?;
        assert_eq!(tensor.shape().dims(), &[2, 2]);

        pool.return_tensor(tensor);
        let reused = pool.get_tensor(&[2, 2], torsh_core::device::DeviceType::Cpu)?;
        assert_eq!(reused.shape().dims(), &[2, 2]);
        Ok(())
    }

    // Once capacity is exceeded the oldest element is evicted and logical
    // indexing (0 = oldest) still holds.
    #[test]
    fn test_circular_buffer() {
        let mut buffer = CircularBuffer::new(3);
        buffer.push(1);
        buffer.push(2);
        buffer.push(3);
        buffer.push(4); // Should overwrite 1

        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.get(0), Some(&2));
        assert_eq!(buffer.get(1), Some(&3));
        assert_eq!(buffer.get(2), Some(&4));
    }

    // Constructor stores the provided learning rate.
    #[test]
    fn test_memory_efficient_adam_creation() -> OptimizerResult<()> {
        let params = vec![Arc::new(RwLock::new(randn::<f32>(&[2, 2])?))];
        let optimizer = MemoryEfficientAdam::new(params, 0.001, None, None, None, None, None, None);
        assert_eq!(optimizer.get_lr()[0], 0.001);
        Ok(())
    }

    // Constructor stores the provided learning rate and history size.
    #[test]
    fn test_memory_efficient_lbfgs_creation() -> OptimizerResult<()> {
        let params = vec![Arc::new(RwLock::new(randn::<f32>(&[2, 2])?))];
        let optimizer =
            MemoryEfficientLBFGS::new(params, Some(0.1), None, None, None, Some(5), None);
        assert_eq!(optimizer.get_lr()[0], 0.1);
        assert_eq!(optimizer.history_size, 5);
        Ok(())
    }

    // Builder setters propagate into the built optimizer's MemoryConfig.
    #[test]
    fn test_builder_pattern() -> OptimizerResult<()> {
        let params = vec![Arc::new(RwLock::new(randn::<f32>(&[2, 2])?))];
        let optimizer = MemoryEfficientOptimizerBuilder::new()
            .max_memory_gb(1.0)
            .use_memory_pool(true)
            .compress_state(true)
            .build_adam(params, 0.001);

        assert_eq!(optimizer.get_lr()[0], 0.001);
        assert_eq!(optimizer.config.max_memory_bytes, 1_000_000_000);
        assert!(optimizer.config.use_memory_pool);
        assert!(optimizer.config.compress_state);
        Ok(())
    }
}