1use crate::common::{BiasCorrection, ParameterUpdate};
16use std::alloc::{alloc, dealloc, Layout};
17use std::ptr::{self, NonNull};
18use trustformers_core::errors::{Result, TrustformersError};
19use trustformers_core::tensor::Tensor;
20use trustformers_core::traits::Optimizer;
21
/// Alignment parameters (all in bytes) used when laying out optimizer state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AlignmentConfig {
    /// CPU cache-line size in bytes (default 64).
    pub cache_line_size: usize,
    /// SIMD vector width in bytes (default 32, i.e. 256-bit; 64 for AVX-512).
    pub vector_size: usize,
    /// Memory page size in bytes (default 4096).
    pub page_size: usize,
    /// Whether huge pages are requested. Not consulted by
    /// [`AlignmentConfig::alignment_for_size`].
    pub use_huge_pages: bool,
}

impl Default for AlignmentConfig {
    /// 64-byte cache lines, 32-byte vectors, 4 KiB pages, huge pages off.
    fn default() -> Self {
        Self {
            cache_line_size: 64,
            vector_size: 32,
            page_size: 4096,
            use_huge_pages: false,
        }
    }
}

impl AlignmentConfig {
    /// Default settings with 64-byte (512-bit) vectors for AVX-512 targets.
    pub fn avx512() -> Self {
        Self {
            vector_size: 64,
            ..Default::default()
        }
    }

    /// Default settings with `use_huge_pages` enabled.
    pub fn with_huge_pages() -> Self {
        Self {
            use_huge_pages: true,
            ..Default::default()
        }
    }

    /// Chooses an alignment (in bytes) for an allocation of `size` bytes.
    ///
    /// * `size >= page_size` → `page_size`
    /// * `size >= cache_line_size` → `cache_line_size`
    /// * otherwise the largest power of two not exceeding `min(vector_size, size)`,
    ///   so the result is a valid `Layout` alignment and never exceeds the request.
    ///
    /// `page_size` and `cache_line_size` are returned as-is and are assumed to be
    /// powers of two (true for every preset above).
    ///
    /// Returns 0 when `size == 0` (or `vector_size == 0`). 0 is not a valid
    /// alignment, so callers must handle empty allocations themselves.
    pub fn alignment_for_size(&self, size: usize) -> usize {
        if size >= self.page_size {
            self.page_size
        } else if size >= self.cache_line_size {
            self.cache_line_size
        } else {
            match self.vector_size.min(size) {
                0 => 0,
                // Round down to a power of two: `min(vector_size, size)` alone yields
                // values such as 12 or 24 (3 or 6 `f32`s), which `Layout` rejects.
                capped => 1usize << (usize::BITS - 1 - capped.leading_zeros()),
            }
        }
    }
}
74
75#[derive(Debug)]
77pub struct AlignedAllocator {
78 config: AlignmentConfig,
79 allocated_blocks: Vec<(NonNull<u8>, Layout)>,
80}
81
82impl AlignedAllocator {
83 pub fn new(config: AlignmentConfig) -> Self {
85 Self {
86 config,
87 allocated_blocks: Vec::new(),
88 }
89 }
90
91 pub fn allocate_aligned<T>(&mut self, count: usize) -> Result<NonNull<T>> {
93 let size = count * std::mem::size_of::<T>();
94 let alignment = self.config.alignment_for_size(size);
95
96 let layout = Layout::from_size_align(size, alignment).map_err(|e| {
97 TrustformersError::tensor_op_error(
98 &format!("Invalid layout: {}", e),
99 "allocate_aligned",
100 )
101 })?;
102
103 let ptr = unsafe { alloc(layout) };
104 if ptr.is_null() {
105 return Err(TrustformersError::tensor_op_error(
106 "Memory allocation failed",
107 "allocate_aligned",
108 ));
109 }
110
111 let non_null = NonNull::new(ptr).ok_or_else(|| {
112 TrustformersError::tensor_op_error("Null pointer in allocation", "allocate_aligned")
113 })?;
114
115 self.allocated_blocks.push((non_null, layout));
116
117 let typed_ptr = non_null.as_ptr() as *mut T;
119 NonNull::new(typed_ptr).ok_or_else(|| {
120 TrustformersError::tensor_op_error("Type casting failed", "allocate_aligned")
121 })
122 }
123
124 pub fn allocate_initialized<T: Clone>(&mut self, count: usize, value: T) -> Result<NonNull<T>> {
126 let ptr = self.allocate_aligned::<T>(count)?;
127
128 unsafe {
129 for i in 0..count {
130 ptr::write(ptr.as_ptr().add(i), value.clone());
131 }
132 }
133
134 Ok(ptr)
135 }
136
137 pub fn memory_usage(&self) -> usize {
139 self.allocated_blocks.iter().map(|(_, layout)| layout.size()).sum()
140 }
141}
142
143impl Drop for AlignedAllocator {
144 fn drop(&mut self) {
145 for (ptr, layout) in &self.allocated_blocks {
146 unsafe {
147 dealloc(ptr.as_ptr(), *layout);
148 }
149 }
150 }
151}
152
153unsafe impl Send for AlignedAllocator {}
156unsafe impl Sync for AlignedAllocator {}
157
158#[derive(Debug)]
163pub struct SoAOptimizerState {
164 momentum_storage: AlignedAllocator,
166 variance_storage: AlignedAllocator,
168 parameters: Vec<ParameterInfo>,
170 step: usize,
172 alignment: AlignmentConfig,
174}
175
176#[derive(Debug, Clone)]
178pub struct ParameterInfo {
179 pub id: String,
181 pub momentum_offset: usize,
183 pub variance_offset: usize,
185 pub size: usize,
187 pub chunk_size: usize,
189}
190
191impl SoAOptimizerState {
192 pub fn new(alignment: AlignmentConfig) -> Self {
194 Self {
195 momentum_storage: AlignedAllocator::new(alignment),
196 variance_storage: AlignedAllocator::new(alignment),
197 parameters: Vec::new(),
198 step: 0,
199 alignment,
200 }
201 }
202
203 pub fn add_parameter(&mut self, id: String, size: usize) -> Result<()> {
205 let chunk_size = self.calculate_optimal_chunk_size(size);
207
208 let _momentum_ptr = self.momentum_storage.allocate_initialized(size, 0.0f32)?;
210 let momentum_offset = self.parameters.len() * size; let _variance_ptr = self.variance_storage.allocate_initialized(size, 0.0f32)?;
214 let variance_offset = self.parameters.len() * size; let param_info = ParameterInfo {
217 id,
218 momentum_offset,
219 variance_offset,
220 size,
221 chunk_size,
222 };
223
224 self.parameters.push(param_info);
225 Ok(())
226 }
227
228 fn calculate_optimal_chunk_size(&self, size: usize) -> usize {
230 let vector_elements = self.alignment.vector_size / std::mem::size_of::<f32>();
231 let cache_line_elements = self.alignment.cache_line_size / std::mem::size_of::<f32>();
232
233 let min_chunk = vector_elements;
235 let preferred_chunk = cache_line_elements;
236
237 if size >= preferred_chunk {
238 preferred_chunk
239 } else if size >= min_chunk {
240 (size / min_chunk) * min_chunk
242 } else {
243 size
244 }
245 }
246
247 pub fn get_parameter_info(&self, id: &str) -> Option<&ParameterInfo> {
249 self.parameters.iter().find(|p| p.id == id)
250 }
251
252 pub fn update_parameter_soa(
254 &mut self,
255 param_id: &str,
256 param: &mut [f32],
257 grad: &[f32],
258 lr: f32,
259 betas: (f32, f32),
260 eps: f32,
261 weight_decay: f32,
262 ) -> Result<()> {
263 let param_info = self
264 .get_parameter_info(param_id)
265 .ok_or_else(|| {
266 TrustformersError::tensor_op_error("Parameter not found", "update_parameter_soa")
267 })?
268 .clone();
269
270 if param.len() != param_info.size || grad.len() != param_info.size {
271 return Err(TrustformersError::tensor_op_error(
272 "Size mismatch",
273 "update_parameter_soa",
274 ));
275 }
276
277 self.step += 1;
278 let (bias_correction1, bias_correction2) =
279 BiasCorrection::compute_adam_corrections(betas.0, betas.1, self.step);
280
281 let chunk_size = param_info.chunk_size;
283 let num_chunks = param_info.size.div_ceil(chunk_size);
284
285 for chunk_idx in 0..num_chunks {
286 let start = chunk_idx * chunk_size;
287 let end = (start + chunk_size).min(param_info.size);
288
289 self.process_chunk_soa(
290 &mut param[start..end],
291 &grad[start..end],
292 start,
293 ¶m_info,
294 lr,
295 betas,
296 bias_correction1,
297 bias_correction2,
298 eps,
299 weight_decay,
300 )?;
301 }
302
303 Ok(())
304 }
305
306 fn process_chunk_soa(
308 &mut self,
309 param_chunk: &mut [f32],
310 grad_chunk: &[f32],
311 offset: usize,
312 param_info: &ParameterInfo,
313 lr: f32,
314 betas: (f32, f32),
315 bias_correction1: f32,
316 bias_correction2: f32,
317 eps: f32,
318 weight_decay: f32,
319 ) -> Result<()> {
320 for i in 0..param_chunk.len() {
324 let grad_val = grad_chunk[i] + weight_decay * param_chunk[i];
325
326 let momentum_idx = param_info.momentum_offset + offset + i;
328 let variance_idx = param_info.variance_offset + offset + i;
329
330 let mut momentum = if momentum_idx < param_info.size {
333 grad_val * 0.9 } else {
336 0.0f32
337 };
338
339 let mut variance = if variance_idx < param_info.size {
340 grad_val * grad_val * 0.999 } else {
343 0.0f32
344 };
345
346 ParameterUpdate::update_ema(&mut momentum, grad_val, betas.0);
348 ParameterUpdate::update_ema(&mut variance, grad_val * grad_val, betas.1);
349
350 let m_hat = momentum / bias_correction1;
352 let v_hat = variance / bias_correction2;
353
354 ParameterUpdate::adam_update(&mut param_chunk[i], lr, m_hat, v_hat, eps);
356
357 }
361
362 Ok(())
363 }
364
365 pub fn layout_stats(&self) -> LayoutStats {
367 let momentum_memory = self.momentum_storage.memory_usage();
368 let variance_memory = self.variance_storage.memory_usage();
369 let total_elements: usize = self.parameters.iter().map(|p| p.size).sum();
370
371 LayoutStats {
372 total_parameters: self.parameters.len(),
373 total_elements,
374 momentum_memory_bytes: momentum_memory,
375 variance_memory_bytes: variance_memory,
376 total_memory_bytes: momentum_memory + variance_memory,
377 alignment_config: self.alignment,
378 cache_line_utilization: self.calculate_cache_line_utilization(),
379 }
380 }
381
382 fn calculate_cache_line_utilization(&self) -> f32 {
384 if self.parameters.is_empty() {
385 return 1.0;
386 }
387
388 let cache_line_elements = self.alignment.cache_line_size / std::mem::size_of::<f32>();
389 let mut total_utilization = 0.0;
390
391 for param in &self.parameters {
392 let lines_used = param.size.div_ceil(cache_line_elements);
393 let elements_in_lines = lines_used * cache_line_elements;
394 let utilization = param.size as f32 / elements_in_lines as f32;
395 total_utilization += utilization;
396 }
397
398 total_utilization / self.parameters.len() as f32
399 }
400}
401
402unsafe impl Send for SoAOptimizerState {}
404unsafe impl Sync for SoAOptimizerState {}
405
406#[derive(Debug, Clone)]
408pub struct LayoutStats {
409 pub total_parameters: usize,
411 pub total_elements: usize,
413 pub momentum_memory_bytes: usize,
415 pub variance_memory_bytes: usize,
417 pub total_memory_bytes: usize,
419 pub alignment_config: AlignmentConfig,
421 pub cache_line_utilization: f32,
423}
424
425impl LayoutStats {
426 pub fn memory_overhead(&self) -> f32 {
428 let naive_memory = self.total_elements * std::mem::size_of::<f32>() * 2; if naive_memory == 0 {
430 return 0.0;
431 }
432 (self.total_memory_bytes as f32 / naive_memory as f32) - 1.0
433 }
434
435 pub fn optimization_suggestions(&self) -> Vec<String> {
437 let mut suggestions = Vec::new();
438
439 if self.cache_line_utilization < 0.8 {
440 suggestions.push("Poor cache line utilization; consider parameter padding".to_string());
441 }
442
443 let overhead = self.memory_overhead();
444 if overhead > 0.2 {
445 suggestions.push(format!(
446 "High memory overhead ({:.1}%); review alignment requirements",
447 overhead * 100.0
448 ));
449 }
450
451 if self.alignment_config.vector_size > 32 && self.total_elements < 1000 {
452 suggestions.push("Vector size may be too large for small parameters".to_string());
453 }
454
455 if !self.alignment_config.use_huge_pages && self.total_memory_bytes > 1024 * 1024 {
456 suggestions.push("Consider enabling huge pages for large memory usage".to_string());
457 }
458
459 if suggestions.is_empty() {
460 suggestions.push("Memory layout appears well optimized".to_string());
461 }
462
463 suggestions
464 }
465}
466
467#[derive(Debug)]
469pub struct LayoutOptimizedAdam {
470 lr: f32,
472 betas: (f32, f32),
474 eps: f32,
476 weight_decay: f32,
478 state: SoAOptimizerState,
480}
481
482impl LayoutOptimizedAdam {
483 pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
485 Self::with_alignment(lr, betas, eps, weight_decay, AlignmentConfig::default())
486 }
487
488 pub fn with_alignment(
490 lr: f32,
491 betas: (f32, f32),
492 eps: f32,
493 weight_decay: f32,
494 alignment: AlignmentConfig,
495 ) -> Self {
496 Self {
497 lr,
498 betas,
499 eps,
500 weight_decay,
501 state: SoAOptimizerState::new(alignment),
502 }
503 }
504
505 pub fn avx512_optimized(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
507 Self::with_alignment(lr, betas, eps, weight_decay, AlignmentConfig::avx512())
508 }
509
510 pub fn layout_stats(&self) -> LayoutStats {
512 self.state.layout_stats()
513 }
514
515 pub fn add_parameter(&mut self, id: String, size: usize) -> Result<()> {
517 self.state.add_parameter(id, size)
518 }
519}
520
521impl Optimizer for LayoutOptimizedAdam {
522 fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
523 match (parameter, grad) {
524 (Tensor::F32(param), Tensor::F32(grad_arr)) => {
525 let param_id = format!("{:p}", param.as_ptr());
526
527 if self.state.get_parameter_info(¶m_id).is_none() {
529 self.state.add_parameter(param_id.clone(), param.len())?;
530 }
531
532 let param_slice = param.as_slice_mut().ok_or_else(|| {
533 TrustformersError::tensor_op_error(
534 "Failed to get mutable slice from param tensor",
535 "update",
536 )
537 })?;
538 let grad_slice = grad_arr.as_slice().ok_or_else(|| {
539 TrustformersError::tensor_op_error(
540 "Failed to get slice from gradient tensor",
541 "update",
542 )
543 })?;
544 self.state.update_parameter_soa(
545 ¶m_id,
546 param_slice,
547 grad_slice,
548 self.lr,
549 self.betas,
550 self.eps,
551 self.weight_decay,
552 )
553 },
554 _ => Err(TrustformersError::tensor_op_error(
555 "Unsupported tensor types for LayoutOptimizedAdam",
556 "update",
557 )),
558 }
559 }
560
561 fn zero_grad(&mut self) {
562 }
564
565 fn step(&mut self) {
566 }
568
569 fn get_lr(&self) -> f32 {
570 self.lr
571 }
572
573 fn set_lr(&mut self, lr: f32) {
574 self.lr = lr;
575 }
576}
577
578unsafe impl Send for LayoutOptimizedAdam {}
580unsafe impl Sync for LayoutOptimizedAdam {}
581
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the preset values and that a mid-sized request gets a sane alignment.
    #[test]
    fn test_alignment_config() {
        let default_cfg = AlignmentConfig::default();
        assert_eq!(default_cfg.cache_line_size, 64);
        assert_eq!(default_cfg.vector_size, 32);
        assert!(!default_cfg.use_huge_pages);

        assert_eq!(AlignmentConfig::avx512().vector_size, 64);

        let chosen = default_cfg.alignment_for_size(1000);
        assert!(chosen > 0 && chosen <= default_cfg.cache_line_size);
    }

    /// Checks that an allocation is reflected in `memory_usage`.
    #[test]
    fn test_aligned_allocator() {
        let mut allocator = AlignedAllocator::new(AlignmentConfig::default());
        allocator
            .allocate_aligned::<f32>(1000)
            .expect("Operation failed in test");

        let required = 1000 * std::mem::size_of::<f32>();
        assert!(allocator.memory_usage() >= required);
    }

    /// Checks that a registered parameter can be looked up and is counted in the stats.
    #[test]
    fn test_soa_optimizer_state() {
        let mut state = SoAOptimizerState::new(AlignmentConfig::default());
        state
            .add_parameter(String::from("param1"), 1000)
            .expect("Operation failed in test");
        assert!(state.get_parameter_info("param1").is_some());

        let LayoutStats {
            total_parameters,
            total_elements,
            ..
        } = state.layout_stats();
        assert_eq!(total_parameters, 1);
        assert_eq!(total_elements, 1000);
    }

    /// Checks that the constructor stores its hyper-parameters and starts with no parameters.
    #[test]
    fn test_layout_optimized_adam() {
        let optimizer = LayoutOptimizedAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.layout_stats().total_parameters, 0);
    }

    /// Checks stats, overhead and suggestions after registering two parameters.
    #[test]
    fn test_layout_stats() {
        let mut state = SoAOptimizerState::new(AlignmentConfig::default());
        for (name, size) in [("param1", 100), ("param2", 200)] {
            state
                .add_parameter(String::from(name), size)
                .expect("Operation failed in test");
        }

        let stats = state.layout_stats();
        assert_eq!(stats.total_parameters, 2);
        assert_eq!(stats.total_elements, 300);
        assert!(stats.cache_line_utilization > 0.0 && stats.cache_line_utilization <= 1.0);
        assert!(stats.memory_overhead() >= 0.0);
        assert!(!stats.optimization_suggestions().is_empty());
    }

    /// Checks that large parameters get bigger, vector-aligned chunks.
    #[test]
    fn test_chunk_size_calculation() {
        let config = AlignmentConfig::default();
        let state = SoAOptimizerState::new(config);

        let large = state.calculate_optimal_chunk_size(10000);
        let small = state.calculate_optimal_chunk_size(5);
        assert!(large > small);

        let vector_elems = config.vector_size / std::mem::size_of::<f32>();
        assert!(large.is_multiple_of(vector_elems) || large == 10000);
    }

    /// Checks that the AVX-512 constructor uses 64-byte vectors.
    #[test]
    fn test_avx512_optimization() {
        let stats =
            LayoutOptimizedAdam::avx512_optimized(1e-3, (0.9, 0.999), 1e-8, 0.01).layout_stats();
        assert_eq!(stats.alignment_config.vector_size, 64);
    }
}
684}