// trustformers_optim/cache_friendly.rs

//! Cache-friendly optimization algorithms for improved memory performance.
//!
//! This module implements optimization algorithms with cache-friendly memory access patterns,
//! reducing cache misses and improving overall performance, especially for large models.
//!
//! # Key Optimizations
//!
//! - **Blocked/Tiled Operations**: Process data in cache-sized blocks
//! - **Memory Layout Optimization**: Structure data for optimal cache utilization
//! - **Data Prefetching**: Improve cache hit rates with strategic prefetching
//! - **Loop Fusion**: Combine operations to reduce memory bandwidth requirements
//! - **Vectorization-Friendly**: Design for SIMD instruction utilization
14use crate::common::{BiasCorrection, ParameterUpdate};
15use std::collections::HashMap;
16use trustformers_core::errors::{Result, TrustformersError};
17use trustformers_core::tensor::Tensor;
18use trustformers_core::traits::Optimizer;
19
/// Cache configuration parameters for memory-aware optimizers.
///
/// Sizes describe the target CPU's cache hierarchy; see [`CacheConfig::default`]
/// for the conservative defaults used when nothing is detected.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// L1 cache size in bytes (typically 32KB)
    pub l1_cache_size: usize,
    /// L2 cache size in bytes (typically 256KB-1MB)
    pub l2_cache_size: usize,
    /// L3 cache size in bytes (typically 8MB-32MB)
    pub l3_cache_size: usize,
    /// Cache line size in bytes (typically 64 bytes)
    pub cache_line_size: usize,
    /// Block size, in elements, for tiled operations
    pub block_size: usize,
    /// Whether to enable prefetching
    pub enable_prefetching: bool,
    /// Prefetch distance (number of cache lines ahead)
    pub prefetch_distance: usize,
}
38
39impl Default for CacheConfig {
40    fn default() -> Self {
41        Self {
42            l1_cache_size: 32 * 1024,       // 32KB
43            l2_cache_size: 256 * 1024,      // 256KB
44            l3_cache_size: 8 * 1024 * 1024, // 8MB
45            cache_line_size: 64,            // 64 bytes
46            block_size: 1024,               // Process 1024 elements at a time
47            enable_prefetching: true,
48            prefetch_distance: 4,
49        }
50    }
51}
52
53impl CacheConfig {
54    /// Detects cache configuration from the system (simplified version).
55    pub fn detect_system() -> Self {
56        // In a real implementation, this would use cpuid or similar
57        // to detect actual cache sizes
58        Self::default()
59    }
60
61    /// Configures for L1 cache optimization (small blocks, high frequency access).
62    pub fn l1_optimized() -> Self {
63        Self {
64            block_size: 512, // Smaller blocks for L1
65            ..Default::default()
66        }
67    }
68
69    /// Configures for L2 cache optimization (medium blocks).
70    pub fn l2_optimized() -> Self {
71        Self {
72            block_size: 2048, // Medium blocks for L2
73            ..Default::default()
74        }
75    }
76
77    /// Configures for L3 cache optimization (larger blocks).
78    pub fn l3_optimized() -> Self {
79        Self {
80            block_size: 8192, // Larger blocks for L3
81            ..Default::default()
82        }
83    }
84
85    /// Calculates optimal block size based on cache hierarchy.
86    pub fn optimal_block_size_for_arrays(&self, num_arrays: usize) -> usize {
87        // Account for multiple arrays (momentum, variance, parameters)
88        let available_cache = self.l2_cache_size / num_arrays;
89        let elements_per_cache = available_cache / std::mem::size_of::<f32>();
90
91        // Use power of 2 for better memory alignment
92        let mut block_size = 64;
93        while block_size * 2 <= elements_per_cache && block_size < 16384 {
94            block_size *= 2;
95        }
96
97        block_size.min(self.block_size)
98    }
99}
100
/// Cache-friendly memory layout for optimizer state.
///
/// This structure organizes optimizer state data to maximize cache utilization
/// by grouping frequently accessed data together.
#[derive(Debug)]
pub struct CacheFriendlyState {
    /// Interleaved momentum and variance data for cache efficiency.
    /// Format: [momentum[i], variance[i], momentum[i+1], variance[i+1], ...]
    /// so each (m, v) pair lands on the same cache line during updates.
    pub interleaved_buffers: HashMap<usize, Vec<f32>>,
    /// Per-parameter metadata (size, tile size, last access) keyed by the same
    /// numeric parameter ID as `interleaved_buffers`.
    pub param_metadata: HashMap<usize, ParameterMetadata>,
    /// Current step counter; advanced externally (see `Optimizer::step`).
    pub step: usize,
    /// Cache configuration used to size tiles for new parameters.
    pub cache_config: CacheConfig,
}
117
/// Metadata for efficient parameter processing.
#[derive(Debug, Clone)]
pub struct ParameterMetadata {
    /// Starting offset in interleaved buffer.
    /// NOTE(review): currently always 0 (each parameter owns its own buffer);
    /// presumably reserved for a future pooled-allocation scheme — confirm.
    pub offset: usize,
    /// Number of elements (the interleaved buffer holds `2 * size` floats)
    pub size: usize,
    /// Optimal block size for this parameter, derived from the cache config
    pub block_size: usize,
    /// Step counter value at last access, used for garbage collection
    pub last_access: usize,
}
130
131impl CacheFriendlyState {
132    /// Creates a new cache-friendly state with the given configuration.
133    pub fn new(cache_config: CacheConfig) -> Self {
134        Self {
135            interleaved_buffers: HashMap::new(),
136            param_metadata: HashMap::new(),
137            step: 0,
138            cache_config,
139        }
140    }
141
142    /// Allocates buffers for a parameter with optimal memory layout.
143    pub fn allocate_parameter(&mut self, param_id: usize, size: usize) -> Result<()> {
144        // Allocate interleaved momentum and variance
145        // Format: [m0, v0, m1, v1, ..., mn, vn]
146        let buffer_size = size * 2; // momentum + variance
147        let buffer = vec![0.0; buffer_size];
148
149        let metadata = ParameterMetadata {
150            offset: 0,
151            size,
152            block_size: self.cache_config.optimal_block_size_for_arrays(3), // param, momentum, variance
153            last_access: self.step,
154        };
155
156        self.interleaved_buffers.insert(param_id, buffer);
157        self.param_metadata.insert(param_id, metadata);
158
159        Ok(())
160    }
161
162    /// Gets direct access to interleaved buffer for efficient in-place operations.
163    pub fn get_interleaved_buffer_mut(&mut self, param_id: usize) -> Option<(&mut [f32], usize)> {
164        if let (Some(buffer), Some(metadata)) = (
165            self.interleaved_buffers.get_mut(&param_id),
166            self.param_metadata.get_mut(&param_id),
167        ) {
168            metadata.last_access = self.step;
169            Some((buffer.as_mut_slice(), metadata.size))
170        } else {
171            None
172        }
173    }
174
175    /// Gets momentum and variance slices for a parameter (backward compatibility).
176    /// Note: This creates temporary vectors and is less efficient than get_interleaved_buffer_mut.
177    pub fn get_buffers_mut(&mut self, param_id: usize) -> Option<(Vec<f32>, Vec<f32>)> {
178        if let (Some(buffer), Some(metadata)) = (
179            self.interleaved_buffers.get(&param_id),
180            self.param_metadata.get_mut(&param_id),
181        ) {
182            metadata.last_access = self.step;
183
184            // Extract momentum and variance from interleaved buffer
185            let mut momentum = Vec::with_capacity(metadata.size);
186            let mut variance = Vec::with_capacity(metadata.size);
187
188            for i in 0..metadata.size {
189                momentum.push(buffer[i * 2]);
190                variance.push(buffer[i * 2 + 1]);
191            }
192
193            Some((momentum, variance))
194        } else {
195            None
196        }
197    }
198
199    /// Updates interleaved buffers with new momentum and variance values.
200    pub fn update_buffers(
201        &mut self,
202        param_id: usize,
203        momentum: &[f32],
204        variance: &[f32],
205    ) -> Result<()> {
206        if let Some(buffer) = self.interleaved_buffers.get_mut(&param_id) {
207            if momentum.len() != variance.len() || momentum.len() * 2 != buffer.len() {
208                return Err(TrustformersError::tensor_op_error(
209                    "Buffer size mismatch",
210                    "update_buffers",
211                ));
212            }
213
214            // Update interleaved buffer
215            for i in 0..momentum.len() {
216                buffer[i * 2] = momentum[i];
217                buffer[i * 2 + 1] = variance[i];
218            }
219
220            Ok(())
221        } else {
222            Err(TrustformersError::tensor_op_error(
223                "Parameter not found",
224                "update_buffers",
225            ))
226        }
227    }
228
229    /// Clears unused buffers to free memory.
230    pub fn garbage_collect(&mut self, access_threshold: usize) {
231        let current_step = self.step;
232        let stale_params: Vec<usize> = self
233            .param_metadata
234            .iter()
235            .filter(|(_, metadata)| current_step - metadata.last_access > access_threshold)
236            .map(|(id, _)| *id)
237            .collect();
238
239        for param_id in stale_params {
240            self.interleaved_buffers.remove(&param_id);
241            self.param_metadata.remove(&param_id);
242        }
243    }
244}
245
/// Cache-friendly Adam optimizer implementation.
///
/// This optimizer uses blocked processing and optimized memory layouts
/// (interleaved momentum/variance) to minimize cache misses and improve
/// performance.
#[derive(Debug)]
pub struct CacheFriendlyAdam {
    /// Learning rate
    lr: f32,
    /// Beta coefficients for momentum (beta1) and variance (beta2) EMAs
    betas: (f32, f32),
    /// Epsilon added to the denominator for numerical stability
    eps: f32,
    /// Weight decay coefficient (coupled L2, folded into the gradient)
    weight_decay: f32,
    /// Cache-friendly per-parameter optimizer state
    state: CacheFriendlyState,
}
263
264impl CacheFriendlyAdam {
265    /// Creates a new cache-friendly Adam optimizer.
266    pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
267        Self::with_cache_config(lr, betas, eps, weight_decay, CacheConfig::default())
268    }
269
270    /// Creates a cache-friendly Adam optimizer with custom cache configuration.
271    pub fn with_cache_config(
272        lr: f32,
273        betas: (f32, f32),
274        eps: f32,
275        weight_decay: f32,
276        cache_config: CacheConfig,
277    ) -> Self {
278        Self {
279            lr,
280            betas,
281            eps,
282            weight_decay,
283            state: CacheFriendlyState::new(cache_config),
284        }
285    }
286
287    /// Updates parameter using cache-friendly blocked processing (legacy wrapper).
288    #[allow(dead_code)]
289    fn update_parameter_blocked(
290        &mut self,
291        param: &mut [f32],
292        grad: &[f32],
293        param_id: String,
294    ) -> Result<()> {
295        // Convert string ID to numeric ID for compatibility
296        let numeric_id = param_id.as_ptr() as usize;
297        self.update_parameter_blocked_fast(param, grad, numeric_id)
298    }
299
300    /// Fast parameter update using numeric IDs to avoid string formatting overhead.
301    fn update_parameter_blocked_fast(
302        &mut self,
303        param: &mut [f32],
304        grad: &[f32],
305        param_id: usize,
306    ) -> Result<()> {
307        let size = param.len();
308        if grad.len() != size {
309            return Err(TrustformersError::tensor_op_error(
310                "Parameter and gradient size mismatch",
311                "update_parameter_blocked_fast",
312            ));
313        }
314
315        // Ensure parameter buffers are allocated with correct size
316        if !self.state.param_metadata.contains_key(&param_id) {
317            self.state.allocate_parameter(param_id, size)?;
318        } else {
319            // Check if size has changed and reallocate if needed
320            let current_size =
321                self.state.param_metadata.get(&param_id).map(|meta| meta.size).unwrap_or(0);
322            if current_size != size {
323                self.state.allocate_parameter(param_id, size)?;
324            }
325        }
326
327        // Extract needed values before borrowing the buffer
328        let step = self.state.step + 1;
329        let block_size = self
330            .state
331            .param_metadata
332            .get(&param_id)
333            .map(|meta| meta.block_size)
334            .unwrap_or(1024);
335        let _enable_prefetching = self.state.cache_config.enable_prefetching;
336
337        let (bias_correction1, bias_correction2) =
338            BiasCorrection::compute_adam_corrections(self.betas.0, self.betas.1, step);
339
340        // Get direct access to interleaved buffer for efficient operations
341        let (interleaved_buffer, _param_size) =
342            self.state.get_interleaved_buffer_mut(param_id).ok_or_else(|| {
343                TrustformersError::tensor_op_error(
344                    "Failed to get parameter buffers",
345                    "update_parameter_blocked_fast",
346                )
347            })?;
348
349        // For smaller tensors (< 4096 elements), use direct processing to avoid block overhead
350        if size < 4096 {
351            // Direct processing with inlined operations for better performance
352            for i in 0..size {
353                let grad_val = grad[i] + self.weight_decay * param[i];
354
355                // Work directly with interleaved buffer: [m0, v0, m1, v1, ...]
356                let momentum_idx = i * 2;
357                let variance_idx = i * 2 + 1;
358
359                // Update momentum and variance with inlined EMA operations
360                interleaved_buffer[momentum_idx] = self.betas.0 * interleaved_buffer[momentum_idx]
361                    + (1.0 - self.betas.0) * grad_val;
362                interleaved_buffer[variance_idx] = self.betas.1 * interleaved_buffer[variance_idx]
363                    + (1.0 - self.betas.1) * grad_val * grad_val;
364
365                // Apply bias-corrected update with inlined operations
366                let m_hat = interleaved_buffer[momentum_idx] / bias_correction1;
367                let v_hat = interleaved_buffer[variance_idx] / bias_correction2;
368
369                param[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
370            }
371        } else {
372            // Use block processing for larger tensors where cache benefits matter
373            let num_blocks = size.div_ceil(block_size);
374
375            for block_idx in 0..num_blocks {
376                let start = block_idx * block_size;
377                let end = (start + block_size).min(size);
378
379                // Note: Prefetching removed to avoid borrowing conflicts - this is a minor optimization
380                // Process current block with inlined operations
381                for i in start..end {
382                    let grad_val = grad[i] + self.weight_decay * param[i];
383
384                    // Work directly with interleaved buffer: [m0, v0, m1, v1, ...]
385                    let momentum_idx = i * 2;
386                    let variance_idx = i * 2 + 1;
387
388                    interleaved_buffer[momentum_idx] = self.betas.0
389                        * interleaved_buffer[momentum_idx]
390                        + (1.0 - self.betas.0) * grad_val;
391                    interleaved_buffer[variance_idx] = self.betas.1
392                        * interleaved_buffer[variance_idx]
393                        + (1.0 - self.betas.1) * grad_val * grad_val;
394
395                    let m_hat = interleaved_buffer[momentum_idx] / bias_correction1;
396                    let v_hat = interleaved_buffer[variance_idx] / bias_correction2;
397
398                    param[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
399                }
400            }
401        }
402
403        // No need to update buffers - we worked directly with the interleaved buffer
404        Ok(())
405    }
406
407    /// Processes a block with fused operations for better cache utilization.
408    #[inline]
409    #[allow(dead_code)]
410    fn process_block_fused(
411        &self,
412        param_block: &mut [f32],
413        grad_block: &[f32],
414        momentum_block: &mut [f32],
415        variance_block: &mut [f32],
416        bias_correction1: f32,
417        bias_correction2: f32,
418    ) {
419        // Fuse all operations for maximum cache efficiency
420        for i in 0..param_block.len() {
421            let grad_val = grad_block[i] + self.weight_decay * param_block[i];
422
423            // Update momentum and variance in one pass
424            ParameterUpdate::update_ema(&mut momentum_block[i], grad_val, self.betas.0);
425            ParameterUpdate::update_ema(&mut variance_block[i], grad_val * grad_val, self.betas.1);
426
427            // Apply bias-corrected update immediately
428            let m_hat = momentum_block[i] / bias_correction1;
429            let v_hat = variance_block[i] / bias_correction2;
430
431            ParameterUpdate::adam_update(&mut param_block[i], self.lr, m_hat, v_hat, self.eps);
432        }
433    }
434
435    /// Software prefetch hint for better cache performance.
436    #[inline]
437    #[allow(dead_code)]
438    fn prefetch_block(&self, block: &[f32]) {
439        // Implement cache-friendly prefetching for different architectures
440        if block.is_empty() {
441            return;
442        }
443
444        // Get pointer to first element
445        let ptr = block.as_ptr();
446
447        // Use architecture-specific prefetch instructions when available
448        #[cfg(target_arch = "x86_64")]
449        {
450            // Prefetch data into L1 cache (temporal locality)
451            // This provides a hint to the processor to load data into cache
452            // Use ptr.wrapping_add to avoid bounds checking in hot path
453            unsafe {
454                std::arch::x86_64::_mm_prefetch(ptr as *const i8, std::arch::x86_64::_MM_HINT_T0);
455
456                // For larger blocks, prefetch multiple cache lines
457                if block.len() > 16 {
458                    // More than one cache line (64 bytes / 4 bytes per f32)
459                    let mid_ptr = ptr.wrapping_add(block.len() / 2);
460                    std::arch::x86_64::_mm_prefetch(
461                        mid_ptr as *const i8,
462                        std::arch::x86_64::_MM_HINT_T0,
463                    );
464                }
465            }
466        }
467
468        #[cfg(target_arch = "aarch64")]
469        {
470            // ARM64 prefetch using inline assembly
471            unsafe {
472                std::arch::asm!(
473                    "prfm pldl1keep, [{}]",
474                    in(reg) ptr,
475                    options(nostack, preserves_flags)
476                );
477            }
478        }
479
480        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
481        {
482            // Fallback: access first element to trigger cache load
483            // This is a software hint that usually gets optimized away
484            // but can help with cache warming on some architectures
485            let _ = unsafe { std::ptr::read_volatile(ptr) };
486        }
487    }
488
489    /// Gets cache utilization statistics.
490    pub fn cache_stats(&self) -> CacheStats {
491        let buffer_memory: usize = self
492            .state
493            .interleaved_buffers
494            .values()
495            .map(|buffer| buffer.len() * std::mem::size_of::<f32>())
496            .sum();
497
498        let num_params = self.state.param_metadata.len();
499        let total_elements: usize = self.state.param_metadata.values().map(|meta| meta.size).sum();
500
501        CacheStats {
502            buffer_memory_bytes: buffer_memory,
503            num_parameters: num_params,
504            total_elements,
505            cache_config: self.state.cache_config.clone(),
506            estimated_l1_utilization: self
507                .estimate_cache_utilization(buffer_memory, self.state.cache_config.l1_cache_size),
508            estimated_l2_utilization: self
509                .estimate_cache_utilization(buffer_memory, self.state.cache_config.l2_cache_size),
510        }
511    }
512
513    /// Estimates cache utilization percentage.
514    fn estimate_cache_utilization(&self, working_set_size: usize, cache_size: usize) -> f32 {
515        if cache_size == 0 {
516            return 1.0;
517        }
518        (working_set_size as f32 / cache_size as f32).min(1.0)
519    }
520
521    /// Performs garbage collection on unused parameters.
522    pub fn cleanup_unused_params(&mut self, steps_threshold: usize) {
523        self.state.garbage_collect(steps_threshold);
524    }
525}
526
impl Optimizer for CacheFriendlyAdam {
    /// Applies one Adam update to `parameter` in place using `grad`.
    ///
    /// Only contiguous `Tensor::F32` parameter/gradient pairs are supported;
    /// any other combination returns an error.
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        match (parameter, grad) {
            (Tensor::F32(param), Tensor::F32(grad_arr)) => {
                // Use pointer address as numeric ID to avoid string formatting overhead.
                // NOTE(review): this ID is only stable while the tensor's storage is
                // never reallocated; a moved/resized parameter silently gets fresh
                // optimizer state — confirm parameters are long-lived in callers.
                let param_id = param.as_ptr() as usize;
                let param_slice = param.as_slice_mut().ok_or_else(|| {
                    TrustformersError::tensor_op_error(
                        "Parameter tensor is not contiguous",
                        "update",
                    )
                })?;
                let grad_slice = grad_arr.as_slice().ok_or_else(|| {
                    TrustformersError::tensor_op_error(
                        "Gradient tensor is not contiguous",
                        "update",
                    )
                })?;
                self.update_parameter_blocked_fast(param_slice, grad_slice, param_id)
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor types for CacheFriendlyAdam",
                "update",
            )),
        }
    }

    /// No-op: gradients are supplied per-update and never stored.
    fn zero_grad(&mut self) {
        // No explicit gradient storage
    }

    /// Advances the shared step counter used for bias correction and GC aging.
    fn step(&mut self) {
        self.state.step += 1;
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}
570
/// Cache performance statistics, produced by `CacheFriendlyAdam::cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Total memory used by optimizer state buffers, in bytes
    pub buffer_memory_bytes: usize,
    /// Number of parameters being optimized
    pub num_parameters: usize,
    /// Total number of parameter elements across all parameters
    pub total_elements: usize,
    /// Cache configuration used
    pub cache_config: CacheConfig,
    /// Estimated L1 cache utilization (0.0 to 1.0, capped at 1.0)
    pub estimated_l1_utilization: f32,
    /// Estimated L2 cache utilization (0.0 to 1.0, capped at 1.0)
    pub estimated_l2_utilization: f32,
}
587
588impl CacheStats {
589    /// Suggests optimization strategies based on cache utilization.
590    pub fn optimization_suggestions(&self) -> Vec<String> {
591        let mut suggestions = Vec::new();
592
593        if self.estimated_l1_utilization > 0.8 {
594            suggestions.push("Consider reducing block size for better L1 cache fit".to_string());
595        }
596
597        if self.estimated_l2_utilization > 0.9 {
598            suggestions
599                .push("Working set exceeds L2 cache; consider parameter partitioning".to_string());
600        }
601
602        if self.cache_config.block_size > 8192 {
603            suggestions.push("Large block size may cause cache thrashing".to_string());
604        }
605
606        if !self.cache_config.enable_prefetching {
607            suggestions.push("Enable prefetching for potential performance gains".to_string());
608        }
609
610        if suggestions.is_empty() {
611            suggestions.push("Cache utilization appears optimal".to_string());
612        }
613
614        suggestions
615    }
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    /// Default and preset cache configurations expose the documented values.
    #[test]
    fn test_cache_config_creation() {
        let config = CacheConfig::default();
        assert_eq!(config.l1_cache_size, 32 * 1024);
        assert_eq!(config.cache_line_size, 64);
        assert!(config.enable_prefetching);

        let l1_config = CacheConfig::l1_optimized();
        assert_eq!(l1_config.block_size, 512);
    }

    /// Computed block sizes are positive powers of two capped at the config.
    #[test]
    fn test_optimal_block_size() {
        let config = CacheConfig::default();
        let block_size = config.optimal_block_size_for_arrays(3);
        assert!(block_size > 0);
        assert!(block_size <= config.block_size);
        assert_eq!(block_size & (block_size - 1), 0); // Should be power of 2
    }

    /// Allocation registers both the buffer and metadata, and the extracted
    /// momentum/variance vectors have the requested element count.
    #[test]
    fn test_cache_friendly_state() {
        let mut state = CacheFriendlyState::new(CacheConfig::default());

        // Test parameter allocation
        let param_id = 12345usize;
        state.allocate_parameter(param_id, 100).unwrap();

        assert!(state.param_metadata.contains_key(&param_id));
        assert!(state.interleaved_buffers.contains_key(&param_id));

        // Test buffer access (de-interleaved copies)
        let (momentum, variance) = state.get_buffers_mut(param_id).unwrap();
        assert_eq!(momentum.len(), 100);
        assert_eq!(variance.len(), 100);
    }

    /// Constructor stores the supplied hyperparameters unchanged.
    #[test]
    fn test_cache_friendly_adam() {
        let optimizer = CacheFriendlyAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.eps, 1e-8);
        assert_eq!(optimizer.weight_decay, 0.01);
    }

    /// A fresh optimizer reports empty stats but still yields suggestions.
    #[test]
    fn test_cache_stats() {
        let optimizer = CacheFriendlyAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        let stats = optimizer.cache_stats();

        assert_eq!(stats.num_parameters, 0);
        assert_eq!(stats.total_elements, 0);
        assert_eq!(stats.buffer_memory_bytes, 0);

        let suggestions = stats.optimization_suggestions();
        assert!(!suggestions.is_empty());
    }

    /// Utilization is the working-set/cache ratio, capped at 1.0.
    #[test]
    fn test_cache_utilization_estimation() {
        let optimizer = CacheFriendlyAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);

        let utilization = optimizer.estimate_cache_utilization(16 * 1024, 32 * 1024);
        assert_eq!(utilization, 0.5);

        let over_utilization = optimizer.estimate_cache_utilization(64 * 1024, 32 * 1024);
        assert_eq!(over_utilization, 1.0);
    }

    /// GC removes only parameters whose last access exceeds the threshold.
    #[test]
    fn test_garbage_collection() {
        let mut state = CacheFriendlyState::new(CacheConfig::default());

        // Add some parameters
        let param1_id = 11111usize;
        let param2_id = 22222usize;
        state.allocate_parameter(param1_id, 100).unwrap();
        state.allocate_parameter(param2_id, 200).unwrap();

        // Simulate time passing
        state.step = 1000;

        // Access only param1 (refreshes its last_access stamp)
        state.get_buffers_mut(param1_id);

        // Garbage collect with threshold of 10 steps
        state.garbage_collect(10);

        // param1 should remain, param2 should be removed
        assert!(state.param_metadata.contains_key(&param1_id));
        assert!(!state.param_metadata.contains_key(&param2_id));
    }
}
715}