// trustformers_optim/memory_layout.rs
//! Memory layout optimizations for improved cache performance.
//!
//! This module provides data structures and algorithms optimized for cache-friendly
//! memory layouts, reducing memory bandwidth usage and improving performance through
//! better spatial and temporal locality.
//!
//! # Key Optimizations
//!
//! - **Structure of Arrays (SoA)**: Better vectorization and cache usage
//! - **Memory Alignment**: Ensure data aligns to cache line boundaries
//! - **Hot/Cold Data Separation**: Keep frequently accessed data together
//! - **Prefetch-Friendly Layouts**: Optimize for hardware prefetchers
//! - **NUMA-Aware Allocation**: Optimize for multi-socket systems

use crate::common::{BiasCorrection, ParameterUpdate};
use std::alloc::{alloc, dealloc, Layout};
use std::ptr::{self, NonNull};
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

/// Memory alignment configuration for optimal cache performance.
#[derive(Debug, Clone, Copy)]
pub struct AlignmentConfig {
    /// Cache line size in bytes (typically 64).
    pub cache_line_size: usize,
    /// Vector register size in bytes (32 for AVX2, 64 for AVX-512).
    /// Assumed to be a power of two.
    pub vector_size: usize,
    /// Page size for large allocations (typically 4 KiB).
    pub page_size: usize,
    /// Enable huge pages for large allocations.
    pub use_huge_pages: bool,
}

impl Default for AlignmentConfig {
    fn default() -> Self {
        Self {
            cache_line_size: 64,
            vector_size: 32, // AVX2
            page_size: 4096,
            use_huge_pages: false,
        }
    }
}

impl AlignmentConfig {
    /// Creates configuration optimized for AVX-512 (64-byte vectors).
    pub fn avx512() -> Self {
        Self {
            vector_size: 64,
            ..Default::default()
        }
    }

    /// Creates configuration with huge pages enabled.
    pub fn with_huge_pages() -> Self {
        Self {
            use_huge_pages: true,
            ..Default::default()
        }
    }

    /// Gets the alignment requirement for the given allocation size in bytes.
    ///
    /// The result is always a power of two (assuming the configured sizes
    /// are powers of two), as required by `Layout::from_size_align`.
    pub fn alignment_for_size(&self, size: usize) -> usize {
        if size >= self.page_size {
            self.page_size
        } else if size >= self.cache_line_size {
            self.cache_line_size
        } else {
            // Small allocations: round the size up to the next power of two
            // (0 maps to 1) so the result is a valid alignment. The previous
            // `vector_size.min(size)` could return a non-power-of-two value
            // (e.g. 3) or zero, which makes `Layout::from_size_align` fail.
            self.vector_size.min(size.next_power_of_two())
        }
    }
}

75/// Aligned memory allocator for cache-friendly data structures.
76#[derive(Debug)]
77pub struct AlignedAllocator {
78    config: AlignmentConfig,
79    allocated_blocks: Vec<(NonNull<u8>, Layout)>,
80}
81
82impl AlignedAllocator {
83    /// Creates a new aligned allocator.
84    pub fn new(config: AlignmentConfig) -> Self {
85        Self {
86            config,
87            allocated_blocks: Vec::new(),
88        }
89    }
90
91    /// Allocates aligned memory for the given type and count.
92    pub fn allocate_aligned<T>(&mut self, count: usize) -> Result<NonNull<T>> {
93        let size = count * std::mem::size_of::<T>();
94        let alignment = self.config.alignment_for_size(size);
95
96        let layout = Layout::from_size_align(size, alignment).map_err(|e| {
97            TrustformersError::tensor_op_error(
98                &format!("Invalid layout: {}", e),
99                "allocate_aligned",
100            )
101        })?;
102
103        let ptr = unsafe { alloc(layout) };
104        if ptr.is_null() {
105            return Err(TrustformersError::tensor_op_error(
106                "Memory allocation failed",
107                "allocate_aligned",
108            ));
109        }
110
111        let non_null = NonNull::new(ptr).ok_or_else(|| {
112            TrustformersError::tensor_op_error("Null pointer in allocation", "allocate_aligned")
113        })?;
114
115        self.allocated_blocks.push((non_null, layout));
116
117        // Cast to the target type
118        let typed_ptr = non_null.as_ptr() as *mut T;
119        NonNull::new(typed_ptr).ok_or_else(|| {
120            TrustformersError::tensor_op_error("Type casting failed", "allocate_aligned")
121        })
122    }
123
124    /// Allocates and initializes aligned memory.
125    pub fn allocate_initialized<T: Clone>(&mut self, count: usize, value: T) -> Result<NonNull<T>> {
126        let ptr = self.allocate_aligned::<T>(count)?;
127
128        unsafe {
129            for i in 0..count {
130                ptr::write(ptr.as_ptr().add(i), value.clone());
131            }
132        }
133
134        Ok(ptr)
135    }
136
137    /// Gets memory usage statistics.
138    pub fn memory_usage(&self) -> usize {
139        self.allocated_blocks.iter().map(|(_, layout)| layout.size()).sum()
140    }
141}
142
143impl Drop for AlignedAllocator {
144    fn drop(&mut self) {
145        for (ptr, layout) in &self.allocated_blocks {
146            unsafe {
147                dealloc(ptr.as_ptr(), *layout);
148            }
149        }
150    }
151}
152
153// Safety: AlignedAllocator manages owned memory allocations properly
154// and the NonNull pointers are used as owned memory handles
155unsafe impl Send for AlignedAllocator {}
156unsafe impl Sync for AlignedAllocator {}
157
158/// Structure of Arrays (SoA) layout for optimizer state.
159///
160/// This layout stores momentum and variance in separate aligned arrays
161/// to improve vectorization and cache utilization.
162#[derive(Debug)]
163pub struct SoAOptimizerState {
164    /// Momentum arrays for all parameters
165    momentum_storage: AlignedAllocator,
166    /// Variance arrays for all parameters
167    variance_storage: AlignedAllocator,
168    /// Parameter metadata
169    parameters: Vec<ParameterInfo>,
170    /// Global step counter
171    step: usize,
172    /// Alignment configuration
173    alignment: AlignmentConfig,
174}
175
/// Placement metadata for one parameter inside the SoA pools.
#[derive(Debug, Clone)]
pub struct ParameterInfo {
    /// Parameter identifier.
    pub id: String,
    /// Element offset of this parameter's momentum in the momentum pool.
    pub momentum_offset: usize,
    /// Element offset of this parameter's variance in the variance pool.
    pub variance_offset: usize,
    /// Number of elements in the parameter.
    pub size: usize,
    /// Chunk length used for cache-friendly iteration.
    pub chunk_size: usize,
}

191impl SoAOptimizerState {
192    /// Creates a new SoA optimizer state.
193    pub fn new(alignment: AlignmentConfig) -> Self {
194        Self {
195            momentum_storage: AlignedAllocator::new(alignment),
196            variance_storage: AlignedAllocator::new(alignment),
197            parameters: Vec::new(),
198            step: 0,
199            alignment,
200        }
201    }
202
203    /// Adds a parameter to the SoA layout.
204    pub fn add_parameter(&mut self, id: String, size: usize) -> Result<()> {
205        // Calculate optimal chunk size for vectorization
206        let chunk_size = self.calculate_optimal_chunk_size(size);
207
208        // Allocate aligned momentum array
209        let _momentum_ptr = self.momentum_storage.allocate_initialized(size, 0.0f32)?;
210        let momentum_offset = self.parameters.len() * size; // Simplified offset calculation
211
212        // Allocate aligned variance array
213        let _variance_ptr = self.variance_storage.allocate_initialized(size, 0.0f32)?;
214        let variance_offset = self.parameters.len() * size; // Simplified offset calculation
215
216        let param_info = ParameterInfo {
217            id,
218            momentum_offset,
219            variance_offset,
220            size,
221            chunk_size,
222        };
223
224        self.parameters.push(param_info);
225        Ok(())
226    }
227
228    /// Calculates optimal chunk size for vectorization.
229    fn calculate_optimal_chunk_size(&self, size: usize) -> usize {
230        let vector_elements = self.alignment.vector_size / std::mem::size_of::<f32>();
231        let cache_line_elements = self.alignment.cache_line_size / std::mem::size_of::<f32>();
232
233        // Choose chunk size that aligns with both vector and cache line boundaries
234        let min_chunk = vector_elements;
235        let preferred_chunk = cache_line_elements;
236
237        if size >= preferred_chunk {
238            preferred_chunk
239        } else if size >= min_chunk {
240            // Round down to nearest vector size
241            (size / min_chunk) * min_chunk
242        } else {
243            size
244        }
245    }
246
247    /// Gets parameter information by ID.
248    pub fn get_parameter_info(&self, id: &str) -> Option<&ParameterInfo> {
249        self.parameters.iter().find(|p| p.id == id)
250    }
251
252    /// Updates momentum and variance for a parameter using optimized memory access.
253    pub fn update_parameter_soa(
254        &mut self,
255        param_id: &str,
256        param: &mut [f32],
257        grad: &[f32],
258        lr: f32,
259        betas: (f32, f32),
260        eps: f32,
261        weight_decay: f32,
262    ) -> Result<()> {
263        let param_info = self
264            .get_parameter_info(param_id)
265            .ok_or_else(|| {
266                TrustformersError::tensor_op_error("Parameter not found", "update_parameter_soa")
267            })?
268            .clone();
269
270        if param.len() != param_info.size || grad.len() != param_info.size {
271            return Err(TrustformersError::tensor_op_error(
272                "Size mismatch",
273                "update_parameter_soa",
274            ));
275        }
276
277        self.step += 1;
278        let (bias_correction1, bias_correction2) =
279            BiasCorrection::compute_adam_corrections(betas.0, betas.1, self.step);
280
281        // Process in cache-friendly chunks
282        let chunk_size = param_info.chunk_size;
283        let num_chunks = param_info.size.div_ceil(chunk_size);
284
285        for chunk_idx in 0..num_chunks {
286            let start = chunk_idx * chunk_size;
287            let end = (start + chunk_size).min(param_info.size);
288
289            self.process_chunk_soa(
290                &mut param[start..end],
291                &grad[start..end],
292                start,
293                &param_info,
294                lr,
295                betas,
296                bias_correction1,
297                bias_correction2,
298                eps,
299                weight_decay,
300            )?;
301        }
302
303        Ok(())
304    }
305
306    /// Processes a chunk using Structure of Arrays layout.
307    fn process_chunk_soa(
308        &mut self,
309        param_chunk: &mut [f32],
310        grad_chunk: &[f32],
311        offset: usize,
312        param_info: &ParameterInfo,
313        lr: f32,
314        betas: (f32, f32),
315        bias_correction1: f32,
316        bias_correction2: f32,
317        eps: f32,
318        weight_decay: f32,
319    ) -> Result<()> {
320        // This is a simplified version - in a real implementation,
321        // we would directly access the aligned momentum and variance arrays
322
323        for i in 0..param_chunk.len() {
324            let grad_val = grad_chunk[i] + weight_decay * param_chunk[i];
325
326            // SoA access simulation - in production would use actual aligned arrays
327            let momentum_idx = param_info.momentum_offset + offset + i;
328            let variance_idx = param_info.variance_offset + offset + i;
329
330            // For now, simulate SoA access with computed values
331            // In production, this would access pre-allocated aligned arrays
332            let mut momentum = if momentum_idx < param_info.size {
333                // Simulate momentum retrieval from aligned storage
334                grad_val * 0.9 // Simplified momentum simulation
335            } else {
336                0.0f32
337            };
338
339            let mut variance = if variance_idx < param_info.size {
340                // Simulate variance retrieval from aligned storage
341                grad_val * grad_val * 0.999 // Simplified variance simulation
342            } else {
343                0.0f32
344            };
345
346            // Update momentum and variance with exponential moving averages
347            ParameterUpdate::update_ema(&mut momentum, grad_val, betas.0);
348            ParameterUpdate::update_ema(&mut variance, grad_val * grad_val, betas.1);
349
350            // Compute bias-corrected estimates
351            let m_hat = momentum / bias_correction1;
352            let v_hat = variance / bias_correction2;
353
354            // Apply Adam update to parameter
355            ParameterUpdate::adam_update(&mut param_chunk[i], lr, m_hat, v_hat, eps);
356
357            // In production, momentum and variance would be written back to aligned arrays
358            // momentum_array[momentum_idx] = momentum;
359            // variance_array[variance_idx] = variance;
360        }
361
362        Ok(())
363    }
364
365    /// Gets memory layout statistics.
366    pub fn layout_stats(&self) -> LayoutStats {
367        let momentum_memory = self.momentum_storage.memory_usage();
368        let variance_memory = self.variance_storage.memory_usage();
369        let total_elements: usize = self.parameters.iter().map(|p| p.size).sum();
370
371        LayoutStats {
372            total_parameters: self.parameters.len(),
373            total_elements,
374            momentum_memory_bytes: momentum_memory,
375            variance_memory_bytes: variance_memory,
376            total_memory_bytes: momentum_memory + variance_memory,
377            alignment_config: self.alignment,
378            cache_line_utilization: self.calculate_cache_line_utilization(),
379        }
380    }
381
382    /// Calculates cache line utilization efficiency.
383    fn calculate_cache_line_utilization(&self) -> f32 {
384        if self.parameters.is_empty() {
385            return 1.0;
386        }
387
388        let cache_line_elements = self.alignment.cache_line_size / std::mem::size_of::<f32>();
389        let mut total_utilization = 0.0;
390
391        for param in &self.parameters {
392            let lines_used = param.size.div_ceil(cache_line_elements);
393            let elements_in_lines = lines_used * cache_line_elements;
394            let utilization = param.size as f32 / elements_in_lines as f32;
395            total_utilization += utilization;
396        }
397
398        total_utilization / self.parameters.len() as f32
399    }
400}
401
// SAFETY: SoAOptimizerState owns its AlignedAllocator pools outright; the
// NonNull handles inside them are owned allocations, not shared references,
// and no interior mutability is exposed through `&self`, so cross-thread
// transfer (Send) and shared access (Sync) are sound.
unsafe impl Send for SoAOptimizerState {}
unsafe impl Sync for SoAOptimizerState {}

406/// Memory layout optimization statistics.
407#[derive(Debug, Clone)]
408pub struct LayoutStats {
409    /// Number of parameters
410    pub total_parameters: usize,
411    /// Total number of elements
412    pub total_elements: usize,
413    /// Memory used by momentum arrays
414    pub momentum_memory_bytes: usize,
415    /// Memory used by variance arrays
416    pub variance_memory_bytes: usize,
417    /// Total memory usage
418    pub total_memory_bytes: usize,
419    /// Alignment configuration
420    pub alignment_config: AlignmentConfig,
421    /// Cache line utilization efficiency (0.0 to 1.0)
422    pub cache_line_utilization: f32,
423}
424
425impl LayoutStats {
426    /// Calculates memory overhead compared to naive layout.
427    pub fn memory_overhead(&self) -> f32 {
428        let naive_memory = self.total_elements * std::mem::size_of::<f32>() * 2; // momentum + variance
429        if naive_memory == 0 {
430            return 0.0;
431        }
432        (self.total_memory_bytes as f32 / naive_memory as f32) - 1.0
433    }
434
435    /// Suggests layout optimizations.
436    pub fn optimization_suggestions(&self) -> Vec<String> {
437        let mut suggestions = Vec::new();
438
439        if self.cache_line_utilization < 0.8 {
440            suggestions.push("Poor cache line utilization; consider parameter padding".to_string());
441        }
442
443        let overhead = self.memory_overhead();
444        if overhead > 0.2 {
445            suggestions.push(format!(
446                "High memory overhead ({:.1}%); review alignment requirements",
447                overhead * 100.0
448            ));
449        }
450
451        if self.alignment_config.vector_size > 32 && self.total_elements < 1000 {
452            suggestions.push("Vector size may be too large for small parameters".to_string());
453        }
454
455        if !self.alignment_config.use_huge_pages && self.total_memory_bytes > 1024 * 1024 {
456            suggestions.push("Consider enabling huge pages for large memory usage".to_string());
457        }
458
459        if suggestions.is_empty() {
460            suggestions.push("Memory layout appears well optimized".to_string());
461        }
462
463        suggestions
464    }
465}
466
467/// Memory-optimized Adam optimizer using SoA layout.
468#[derive(Debug)]
469pub struct LayoutOptimizedAdam {
470    /// Learning rate
471    lr: f32,
472    /// Beta coefficients
473    betas: (f32, f32),
474    /// Epsilon for numerical stability
475    eps: f32,
476    /// Weight decay coefficient
477    weight_decay: f32,
478    /// SoA optimizer state
479    state: SoAOptimizerState,
480}
481
482impl LayoutOptimizedAdam {
483    /// Creates a new layout-optimized Adam optimizer.
484    pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
485        Self::with_alignment(lr, betas, eps, weight_decay, AlignmentConfig::default())
486    }
487
488    /// Creates an optimizer with custom alignment configuration.
489    pub fn with_alignment(
490        lr: f32,
491        betas: (f32, f32),
492        eps: f32,
493        weight_decay: f32,
494        alignment: AlignmentConfig,
495    ) -> Self {
496        Self {
497            lr,
498            betas,
499            eps,
500            weight_decay,
501            state: SoAOptimizerState::new(alignment),
502        }
503    }
504
505    /// Creates an AVX-512 optimized variant.
506    pub fn avx512_optimized(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
507        Self::with_alignment(lr, betas, eps, weight_decay, AlignmentConfig::avx512())
508    }
509
510    /// Gets layout optimization statistics.
511    pub fn layout_stats(&self) -> LayoutStats {
512        self.state.layout_stats()
513    }
514
515    /// Adds a parameter to the optimizer with optimal layout.
516    pub fn add_parameter(&mut self, id: String, size: usize) -> Result<()> {
517        self.state.add_parameter(id, size)
518    }
519}
520
impl Optimizer for LayoutOptimizedAdam {
    /// Applies one Adam step to `parameter` using `grad`.
    ///
    /// Only the `(Tensor::F32, Tensor::F32)` combination is supported; any
    /// other variant pairing returns an error.
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        match (parameter, grad) {
            (Tensor::F32(param), Tensor::F32(grad_arr)) => {
                // NOTE(review): the parameter is identified by its current
                // data pointer; if the tensor's storage is ever reallocated
                // (or another tensor later reuses the same address), state
                // would be re-registered or aliased — confirm parameter
                // tensors keep stable storage across steps.
                let param_id = format!("{:p}", param.as_ptr());

                // Lazily register the parameter the first time it is seen.
                if self.state.get_parameter_info(&param_id).is_none() {
                    self.state.add_parameter(param_id.clone(), param.len())?;
                }

                // NOTE(review): the `unwrap()`s assume contiguous arrays so
                // slice views exist — verify against the Tensor API, or map
                // to an error instead of panicking.
                self.state.update_parameter_soa(
                    &param_id,
                    param.as_slice_mut().unwrap(),
                    grad_arr.as_slice().unwrap(),
                    self.lr,
                    self.betas,
                    self.eps,
                    self.weight_decay,
                )
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor types for LayoutOptimizedAdam",
                "update",
            )),
        }
    }

    /// No-op: gradients are supplied by the caller, not stored here.
    fn zero_grad(&mut self) {
        // No explicit gradient storage
    }

    /// No-op: the step counter advances inside `update_parameter_soa`.
    fn step(&mut self) {
        // Step counter is handled in update_parameter_soa
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}

// SAFETY: LayoutOptimizedAdam is an aggregate of Copy scalar fields plus
// SoAOptimizerState, which is itself (unsafely) marked Send/Sync above on
// the grounds that it owns its allocations; no additional shared interior
// mutability is introduced here.
unsafe impl Send for LayoutOptimizedAdam {}
unsafe impl Sync for LayoutOptimizedAdam {}

#[cfg(test)]
mod tests {
    use super::*;

    // Default/AVX-512 configs expose the documented constants, and the
    // alignment chosen for a mid-sized request stays within one cache line.
    #[test]
    fn test_alignment_config() {
        let config = AlignmentConfig::default();
        assert_eq!(config.cache_line_size, 64);
        assert_eq!(config.vector_size, 32);
        assert!(!config.use_huge_pages);

        let avx512_config = AlignmentConfig::avx512();
        assert_eq!(avx512_config.vector_size, 64);

        let alignment = config.alignment_for_size(1000);
        assert!(alignment > 0);
        assert!(alignment <= config.cache_line_size);
    }

    // Allocation succeeds and the recorded usage covers at least the
    // requested payload (it may be larger due to alignment padding).
    #[test]
    fn test_aligned_allocator() {
        let config = AlignmentConfig::default();
        let mut allocator = AlignedAllocator::new(config);

        let _ptr = allocator.allocate_aligned::<f32>(1000).unwrap();
        // Pointer is allocated successfully

        let memory_usage = allocator.memory_usage();
        assert!(memory_usage >= 1000 * std::mem::size_of::<f32>());
    }

    // Registering a parameter makes it discoverable by ID and reflected in
    // the aggregate layout statistics.
    #[test]
    fn test_soa_optimizer_state() {
        let config = AlignmentConfig::default();
        let mut state = SoAOptimizerState::new(config);

        state.add_parameter("param1".to_string(), 1000).unwrap();
        assert!(state.get_parameter_info("param1").is_some());

        let stats = state.layout_stats();
        assert_eq!(stats.total_parameters, 1);
        assert_eq!(stats.total_elements, 1000);
    }

    // A freshly constructed optimizer carries its hyperparameters and has
    // no registered parameters yet.
    #[test]
    fn test_layout_optimized_adam() {
        let optimizer = LayoutOptimizedAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.betas, (0.9, 0.999));

        let stats = optimizer.layout_stats();
        assert_eq!(stats.total_parameters, 0);
    }

    // Aggregate statistics, utilization bounds, overhead, and suggestion
    // generation across multiple parameters of different sizes.
    #[test]
    fn test_layout_stats() {
        let config = AlignmentConfig::default();
        let mut state = SoAOptimizerState::new(config);

        state.add_parameter("param1".to_string(), 100).unwrap();
        state.add_parameter("param2".to_string(), 200).unwrap();

        let stats = state.layout_stats();
        assert_eq!(stats.total_parameters, 2);
        assert_eq!(stats.total_elements, 300);
        assert!(stats.cache_line_utilization > 0.0);
        assert!(stats.cache_line_utilization <= 1.0);

        let overhead = stats.memory_overhead();
        assert!(overhead >= 0.0);

        let suggestions = stats.optimization_suggestions();
        assert!(!suggestions.is_empty());
    }

    // Chunk sizes scale with parameter size and stay aligned to the vector
    // width (for sizes at least one vector register long).
    #[test]
    fn test_chunk_size_calculation() {
        let config = AlignmentConfig::default();
        let state = SoAOptimizerState::new(config);

        let chunk_size_large = state.calculate_optimal_chunk_size(10000);
        let chunk_size_small = state.calculate_optimal_chunk_size(5);

        assert!(chunk_size_large > chunk_size_small);
        assert!(
            chunk_size_large % (config.vector_size / std::mem::size_of::<f32>()) == 0
                || chunk_size_large == 10000
        );
    }

    // The AVX-512 constructor propagates the wider vector size into stats.
    #[test]
    fn test_avx512_optimization() {
        let optimizer = LayoutOptimizedAdam::avx512_optimized(1e-3, (0.9, 0.999), 1e-8, 0.01);
        let stats = optimizer.layout_stats();
        assert_eq!(stats.alignment_config.vector_size, 64);
    }
}