use crate::common::{BiasCorrection, ParameterUpdate};
use std::alloc::{alloc, dealloc, Layout};
use std::ptr::{self, NonNull};
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

/// Configuration describing the alignment targets for optimizer state buffers.
#[derive(Debug, Clone, Copy)]
pub struct AlignmentConfig {
    /// Cache line size in bytes.
    pub cache_line_size: usize,
    /// SIMD vector width in bytes.
    pub vector_size: usize,
    /// Memory page size in bytes.
    pub page_size: usize,
    /// Whether huge pages should be requested for large allocations.
    pub use_huge_pages: bool,
}

impl Default for AlignmentConfig {
    fn default() -> Self {
        Self {
            cache_line_size: 64,
            vector_size: 32,
            page_size: 4096,
            use_huge_pages: false,
        }
    }
}

impl AlignmentConfig {
    /// Configuration tuned for AVX-512 (64-byte vectors).
    pub fn avx512() -> Self {
        Self {
            vector_size: 64,
            ..Default::default()
        }
    }

    /// Configuration that requests huge pages for large allocations.
    pub fn with_huge_pages() -> Self {
        Self {
            use_huge_pages: true,
            ..Default::default()
        }
    }

    /// Choose the allocation alignment (in bytes) to use for `size` bytes.
    pub fn alignment_for_size(&self, size: usize) -> usize {
        if size >= self.page_size {
            self.page_size
        } else if size >= self.cache_line_size {
            self.cache_line_size
        } else {
            // The alignment passed to `Layout::from_size_align` must be a power
            // of two, so round small sizes up before capping at the vector width.
            self.vector_size.min(size.next_power_of_two())
        }
    }
}

/// Allocator that returns aligned blocks and releases them when dropped.
#[derive(Debug)]
pub struct AlignedAllocator {
    config: AlignmentConfig,
    allocated_blocks: Vec<(NonNull<u8>, Layout)>,
}

impl AlignedAllocator {
    pub fn new(config: AlignmentConfig) -> Self {
        Self {
            config,
            allocated_blocks: Vec::new(),
        }
    }

    /// Allocate an uninitialized, aligned buffer for `count` elements of `T`.
    pub fn allocate_aligned<T>(&mut self, count: usize) -> Result<NonNull<T>> {
        let size = count * std::mem::size_of::<T>();
        let alignment = self.config.alignment_for_size(size);

        let layout = Layout::from_size_align(size, alignment).map_err(|e| {
            TrustformersError::tensor_op_error(
                &format!("Invalid layout: {}", e),
                "allocate_aligned",
            )
        })?;

        let ptr = unsafe { alloc(layout) };
        if ptr.is_null() {
            return Err(TrustformersError::tensor_op_error(
                "Memory allocation failed",
                "allocate_aligned",
            ));
        }

        let non_null = NonNull::new(ptr).ok_or_else(|| {
            TrustformersError::tensor_op_error("Null pointer in allocation", "allocate_aligned")
        })?;

        // Track the block so it can be released in `Drop`.
        self.allocated_blocks.push((non_null, layout));

        let typed_ptr = non_null.as_ptr() as *mut T;
        NonNull::new(typed_ptr).ok_or_else(|| {
            TrustformersError::tensor_op_error("Type casting failed", "allocate_aligned")
        })
    }

    /// Allocate an aligned buffer and fill it with `value`.
    pub fn allocate_initialized<T: Clone>(&mut self, count: usize, value: T) -> Result<NonNull<T>> {
        let ptr = self.allocate_aligned::<T>(count)?;

        unsafe {
            for i in 0..count {
                ptr::write(ptr.as_ptr().add(i), value.clone());
            }
        }

        Ok(ptr)
    }

    /// Total number of bytes currently allocated by this allocator.
    pub fn memory_usage(&self) -> usize {
        self.allocated_blocks.iter().map(|(_, layout)| layout.size()).sum()
    }
}

impl Drop for AlignedAllocator {
    fn drop(&mut self) {
        for (ptr, layout) in &self.allocated_blocks {
            unsafe {
                dealloc(ptr.as_ptr(), *layout);
            }
        }
    }
}

// SAFETY: the allocator exclusively owns the blocks it tracks; they are only
// freed in `Drop`, and shared (`&self`) access never dereferences the raw pointers.
unsafe impl Send for AlignedAllocator {}
unsafe impl Sync for AlignedAllocator {}

/// Structure-of-arrays (SoA) optimizer state with cache-aligned storage.
#[derive(Debug)]
pub struct SoAOptimizerState {
    /// Aligned storage for first-moment (momentum) estimates.
    momentum_storage: AlignedAllocator,
    /// Aligned storage for second-moment (variance) estimates.
    variance_storage: AlignedAllocator,
    /// Layout information for each registered parameter.
    parameters: Vec<ParameterInfo>,
    /// Number of optimizer steps taken so far.
    step: usize,
    /// Alignment configuration used for all allocations.
    alignment: AlignmentConfig,
}

/// Layout metadata recorded for a single registered parameter.
#[derive(Debug, Clone)]
pub struct ParameterInfo {
    /// Identifier of the parameter.
    pub id: String,
    /// Element offset of the parameter's momentum state.
    pub momentum_offset: usize,
    /// Element offset of the parameter's variance state.
    pub variance_offset: usize,
    /// Number of elements in the parameter.
    pub size: usize,
    /// Chunk size (in elements) used when processing the parameter.
    pub chunk_size: usize,
}

impl SoAOptimizerState {
    /// Create an empty state using the given alignment configuration.
    pub fn new(alignment: AlignmentConfig) -> Self {
        Self {
            momentum_storage: AlignedAllocator::new(alignment),
            variance_storage: AlignedAllocator::new(alignment),
            parameters: Vec::new(),
            step: 0,
            alignment,
        }
    }

    /// Register a parameter and allocate aligned momentum and variance buffers
    /// for it.
    pub fn add_parameter(&mut self, id: String, size: usize) -> Result<()> {
        let chunk_size = self.calculate_optimal_chunk_size(size);

        // Zero-initialized, aligned first-moment (momentum) storage.
        let _momentum_ptr = self.momentum_storage.allocate_initialized(size, 0.0f32)?;
        let momentum_offset = self.parameters.len() * size;

        // Zero-initialized, aligned second-moment (variance) storage.
        let _variance_ptr = self.variance_storage.allocate_initialized(size, 0.0f32)?;
        let variance_offset = self.parameters.len() * size;

        let param_info = ParameterInfo {
            id,
            momentum_offset,
            variance_offset,
            size,
            chunk_size,
        };

        self.parameters.push(param_info);
        Ok(())
    }

    /// Pick a chunk size that is a multiple of the SIMD vector width and at
    /// most one cache line of `f32` elements.
    fn calculate_optimal_chunk_size(&self, size: usize) -> usize {
        let vector_elements = self.alignment.vector_size / std::mem::size_of::<f32>();
        let cache_line_elements = self.alignment.cache_line_size / std::mem::size_of::<f32>();

        let min_chunk = vector_elements;
        let preferred_chunk = cache_line_elements;

        if size >= preferred_chunk {
            preferred_chunk
        } else if size >= min_chunk {
            // Round down to a whole number of vector-width chunks.
            (size / min_chunk) * min_chunk
        } else {
            size
        }
    }

    /// Look up the layout information recorded for a parameter id.
    pub fn get_parameter_info(&self, id: &str) -> Option<&ParameterInfo> {
        self.parameters.iter().find(|p| p.id == id)
    }

    /// Run one Adam-style update for a registered parameter, processing the
    /// data in cache-friendly chunks.
    #[allow(clippy::too_many_arguments)]
    pub fn update_parameter_soa(
        &mut self,
        param_id: &str,
        param: &mut [f32],
        grad: &[f32],
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
    ) -> Result<()> {
        let param_info = self
            .get_parameter_info(param_id)
            .ok_or_else(|| {
                TrustformersError::tensor_op_error("Parameter not found", "update_parameter_soa")
            })?
            .clone();

        if param.len() != param_info.size || grad.len() != param_info.size {
            return Err(TrustformersError::tensor_op_error(
                "Size mismatch",
                "update_parameter_soa",
            ));
        }

        self.step += 1;
        let (bias_correction1, bias_correction2) =
            BiasCorrection::compute_adam_corrections(betas.0, betas.1, self.step);

        let chunk_size = param_info.chunk_size;
        let num_chunks = param_info.size.div_ceil(chunk_size);

        for chunk_idx in 0..num_chunks {
            let start = chunk_idx * chunk_size;
            let end = (start + chunk_size).min(param_info.size);

            self.process_chunk_soa(
                &mut param[start..end],
                &grad[start..end],
                start,
                &param_info,
                lr,
                betas,
                bias_correction1,
                bias_correction2,
                eps,
                weight_decay,
            )?;
        }

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    fn process_chunk_soa(
        &mut self,
        param_chunk: &mut [f32],
        grad_chunk: &[f32],
        offset: usize,
        param_info: &ParameterInfo,
        lr: f32,
        betas: (f32, f32),
        bias_correction1: f32,
        bias_correction2: f32,
        eps: f32,
        weight_decay: f32,
    ) -> Result<()> {
        for i in 0..param_chunk.len() {
            // Apply L2-style weight decay to the raw gradient.
            let grad_val = grad_chunk[i] + weight_decay * param_chunk[i];

            let momentum_idx = param_info.momentum_offset + offset + i;
            let variance_idx = param_info.variance_offset + offset + i;

            // Seed the per-element moment estimates; indices outside the
            // recorded parameter size start from zero.
            let mut momentum = if momentum_idx < param_info.size {
                grad_val * 0.9
            } else {
                0.0f32
            };

            let mut variance = if variance_idx < param_info.size {
                grad_val * grad_val * 0.999
            } else {
                0.0f32
            };

            // Exponential moving averages of the first and second moments.
            ParameterUpdate::update_ema(&mut momentum, grad_val, betas.0);
            ParameterUpdate::update_ema(&mut variance, grad_val * grad_val, betas.1);

            // Bias-corrected estimates.
            let m_hat = momentum / bias_correction1;
            let v_hat = variance / bias_correction2;

            ParameterUpdate::adam_update(&mut param_chunk[i], lr, m_hat, v_hat, eps);
        }

        Ok(())
    }

    /// Summarize the memory layout of the optimizer state.
    pub fn layout_stats(&self) -> LayoutStats {
        let momentum_memory = self.momentum_storage.memory_usage();
        let variance_memory = self.variance_storage.memory_usage();
        let total_elements: usize = self.parameters.iter().map(|p| p.size).sum();

        LayoutStats {
            total_parameters: self.parameters.len(),
            total_elements,
            momentum_memory_bytes: momentum_memory,
            variance_memory_bytes: variance_memory,
            total_memory_bytes: momentum_memory + variance_memory,
            alignment_config: self.alignment,
            cache_line_utilization: self.calculate_cache_line_utilization(),
        }
    }

    /// Average fraction of each occupied cache line that holds live elements.
    fn calculate_cache_line_utilization(&self) -> f32 {
        if self.parameters.is_empty() {
            return 1.0;
        }

        let cache_line_elements = self.alignment.cache_line_size / std::mem::size_of::<f32>();
        let mut total_utilization = 0.0;

        for param in &self.parameters {
            let lines_used = param.size.div_ceil(cache_line_elements);
            let elements_in_lines = lines_used * cache_line_elements;
            let utilization = param.size as f32 / elements_in_lines as f32;
            total_utilization += utilization;
        }

        total_utilization / self.parameters.len() as f32
    }
}

// SAFETY: the contained allocators exclusively own their memory blocks.
unsafe impl Send for SoAOptimizerState {}
unsafe impl Sync for SoAOptimizerState {}

/// Statistics describing the optimizer state's memory layout.
#[derive(Debug, Clone)]
pub struct LayoutStats {
    /// Number of registered parameters.
    pub total_parameters: usize,
    /// Total number of `f32` elements across all parameters.
    pub total_elements: usize,
    /// Bytes allocated for first-moment (momentum) storage.
    pub momentum_memory_bytes: usize,
    /// Bytes allocated for second-moment (variance) storage.
    pub variance_memory_bytes: usize,
    /// Total bytes allocated for optimizer state.
    pub total_memory_bytes: usize,
    /// Alignment configuration used for the allocations.
    pub alignment_config: AlignmentConfig,
    /// Average cache line utilization in `[0, 1]`.
    pub cache_line_utilization: f32,
}

impl LayoutStats {
    /// Relative memory overhead versus an unaligned layout holding one
    /// momentum and one variance value per element.
    pub fn memory_overhead(&self) -> f32 {
        // Naive layout: one f32 each for momentum and variance per element.
        let naive_memory = self.total_elements * std::mem::size_of::<f32>() * 2;
        if naive_memory == 0 {
            return 0.0;
        }
        (self.total_memory_bytes as f32 / naive_memory as f32) - 1.0
    }

    /// Heuristic suggestions for improving the memory layout.
    pub fn optimization_suggestions(&self) -> Vec<String> {
        let mut suggestions = Vec::new();

        if self.cache_line_utilization < 0.8 {
            suggestions.push("Poor cache line utilization; consider parameter padding".to_string());
        }

        let overhead = self.memory_overhead();
        if overhead > 0.2 {
            suggestions.push(format!(
                "High memory overhead ({:.1}%); review alignment requirements",
                overhead * 100.0
            ));
        }

        if self.alignment_config.vector_size > 32 && self.total_elements < 1000 {
            suggestions.push("Vector size may be too large for small parameters".to_string());
        }

        if !self.alignment_config.use_huge_pages && self.total_memory_bytes > 1024 * 1024 {
            suggestions.push("Consider enabling huge pages for large memory usage".to_string());
        }

        if suggestions.is_empty() {
            suggestions.push("Memory layout appears well optimized".to_string());
        }

        suggestions
    }
}

/// Adam optimizer whose internal state uses the cache-aligned SoA layout.
#[derive(Debug)]
pub struct LayoutOptimizedAdam {
    /// Learning rate.
    lr: f32,
    /// Exponential decay rates for the first and second moment estimates.
    betas: (f32, f32),
    /// Numerical stability epsilon.
    eps: f32,
    /// Weight decay coefficient.
    weight_decay: f32,
    /// Structure-of-arrays optimizer state.
    state: SoAOptimizerState,
}

impl LayoutOptimizedAdam {
    /// Create an optimizer with the default alignment configuration.
    pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
        Self::with_alignment(lr, betas, eps, weight_decay, AlignmentConfig::default())
    }

    /// Create an optimizer with an explicit alignment configuration.
    pub fn with_alignment(
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        alignment: AlignmentConfig,
    ) -> Self {
        Self {
            lr,
            betas,
            eps,
            weight_decay,
            state: SoAOptimizerState::new(alignment),
        }
    }

    /// Create an optimizer using the AVX-512 alignment preset.
    pub fn avx512_optimized(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
        Self::with_alignment(lr, betas, eps, weight_decay, AlignmentConfig::avx512())
    }

    /// Current layout statistics for the optimizer state.
    pub fn layout_stats(&self) -> LayoutStats {
        self.state.layout_stats()
    }

    /// Pre-register a parameter so its state buffers are allocated up front.
    pub fn add_parameter(&mut self, id: String, size: usize) -> Result<()> {
        self.state.add_parameter(id, size)
    }
}

impl Optimizer for LayoutOptimizedAdam {
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        match (parameter, grad) {
            (Tensor::F32(param), Tensor::F32(grad_arr)) => {
                // Use the parameter's address as a stable identifier.
                let param_id = format!("{:p}", param.as_ptr());

                // Lazily register the parameter on first use.
                if self.state.get_parameter_info(&param_id).is_none() {
                    self.state.add_parameter(param_id.clone(), param.len())?;
                }

                self.state.update_parameter_soa(
                    &param_id,
                    param.as_slice_mut().unwrap(),
                    grad_arr.as_slice().unwrap(),
                    self.lr,
                    self.betas,
                    self.eps,
                    self.weight_decay,
                )
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor types for LayoutOptimizedAdam",
                "update",
            )),
        }
    }

    fn zero_grad(&mut self) {
        // Gradients are owned by the caller; there is nothing to reset here.
    }

    fn step(&mut self) {
        // The internal step counter is advanced in `update_parameter_soa`.
    }

    fn get_lr(&self) -> f32 {
        self.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}

// SAFETY: all interior raw pointers are owned by the contained allocators.
unsafe impl Send for LayoutOptimizedAdam {}
unsafe impl Sync for LayoutOptimizedAdam {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_alignment_config() {
        let config = AlignmentConfig::default();
        assert_eq!(config.cache_line_size, 64);
        assert_eq!(config.vector_size, 32);
        assert!(!config.use_huge_pages);

        let avx512_config = AlignmentConfig::avx512();
        assert_eq!(avx512_config.vector_size, 64);

        let alignment = config.alignment_for_size(1000);
        assert!(alignment > 0);
        assert!(alignment <= config.cache_line_size);
    }

    #[test]
    fn test_aligned_allocator() {
        let config = AlignmentConfig::default();
        let mut allocator = AlignedAllocator::new(config);

        let _ptr = allocator.allocate_aligned::<f32>(1000).unwrap();

        let memory_usage = allocator.memory_usage();
        assert!(memory_usage >= 1000 * std::mem::size_of::<f32>());
    }
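
    // Added test sketch (not in the original suite): `allocate_initialized`
    // should fill every element with the supplied value. Uses only APIs
    // defined in this file.
    #[test]
    fn test_allocate_initialized() {
        let config = AlignmentConfig::default();
        let mut allocator = AlignedAllocator::new(config);

        let ptr = allocator.allocate_initialized::<f32>(64, 1.5).unwrap();
        unsafe {
            for i in 0..64 {
                assert_eq!(*ptr.as_ptr().add(i), 1.5);
            }
        }
        assert!(allocator.memory_usage() >= 64 * std::mem::size_of::<f32>());
    }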

    #[test]
    fn test_soa_optimizer_state() {
        let config = AlignmentConfig::default();
        let mut state = SoAOptimizerState::new(config);

        state.add_parameter("param1".to_string(), 1000).unwrap();
        assert!(state.get_parameter_info("param1").is_some());

        let stats = state.layout_stats();
        assert_eq!(stats.total_parameters, 1);
        assert_eq!(stats.total_elements, 1000);
    }
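
    // Added test sketch (not in the original suite): exercises the validation
    // paths of `update_parameter_soa`. Assumes the `crate::common` helpers
    // (`BiasCorrection`, `ParameterUpdate`) are plain arithmetic and do not panic.
    #[test]
    fn test_update_parameter_soa_validation() {
        let config = AlignmentConfig::default();
        let mut state = SoAOptimizerState::new(config);
        state.add_parameter("param1".to_string(), 8).unwrap();

        let mut param = vec![1.0f32; 8];
        let grad = vec![0.1f32; 8];

        // Matching sizes and a registered id should succeed.
        assert!(state
            .update_parameter_soa("param1", &mut param, &grad, 1e-3, (0.9, 0.999), 1e-8, 0.0)
            .is_ok());

        // An unknown parameter id is rejected.
        assert!(state
            .update_parameter_soa("missing", &mut param, &grad, 1e-3, (0.9, 0.999), 1e-8, 0.0)
            .is_err());

        // A size mismatch with the registered parameter is rejected.
        let mut short = vec![1.0f32; 4];
        let short_grad = vec![0.1f32; 4];
        assert!(state
            .update_parameter_soa("param1", &mut short, &short_grad, 1e-3, (0.9, 0.999), 1e-8, 0.0)
            .is_err());
    }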

    #[test]
    fn test_layout_optimized_adam() {
        let optimizer = LayoutOptimizedAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.betas, (0.9, 0.999));

        let stats = optimizer.layout_stats();
        assert_eq!(stats.total_parameters, 0);
    }
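
    // Added test sketch (not in the original suite): registering a parameter
    // through the optimizer wrapper should show up in its layout statistics.
    #[test]
    fn test_layout_optimized_adam_add_parameter() {
        let mut optimizer = LayoutOptimizedAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        optimizer.add_parameter("weights".to_string(), 512).unwrap();

        let stats = optimizer.layout_stats();
        assert_eq!(stats.total_parameters, 1);
        assert_eq!(stats.total_elements, 512);
        // Momentum and variance buffers hold at least one f32 per element each.
        assert!(stats.total_memory_bytes >= 2 * 512 * std::mem::size_of::<f32>());
    }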

    #[test]
    fn test_layout_stats() {
        let config = AlignmentConfig::default();
        let mut state = SoAOptimizerState::new(config);

        state.add_parameter("param1".to_string(), 100).unwrap();
        state.add_parameter("param2".to_string(), 200).unwrap();

        let stats = state.layout_stats();
        assert_eq!(stats.total_parameters, 2);
        assert_eq!(stats.total_elements, 300);
        assert!(stats.cache_line_utilization > 0.0);
        assert!(stats.cache_line_utilization <= 1.0);

        let overhead = stats.memory_overhead();
        assert!(overhead >= 0.0);

        let suggestions = stats.optimization_suggestions();
        assert!(!suggestions.is_empty());
    }

    #[test]
    fn test_chunk_size_calculation() {
        let config = AlignmentConfig::default();
        let state = SoAOptimizerState::new(config);

        let chunk_size_large = state.calculate_optimal_chunk_size(10000);
        let chunk_size_small = state.calculate_optimal_chunk_size(5);

        assert!(chunk_size_large > chunk_size_small);
        assert!(
            chunk_size_large % (config.vector_size / std::mem::size_of::<f32>()) == 0
                || chunk_size_large == 10000
        );
    }
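
    // Added test sketch (not in the original suite): checks the size -> alignment
    // tiers implemented by `alignment_for_size` for the default configuration.
    #[test]
    fn test_alignment_for_size_tiers() {
        let config = AlignmentConfig::default();

        // Page-sized and larger allocations are page aligned.
        assert_eq!(config.alignment_for_size(8192), config.page_size);
        // Mid-sized allocations fall back to cache-line alignment.
        assert_eq!(config.alignment_for_size(256), config.cache_line_size);
        // Tiny allocations never exceed their own size.
        assert!(config.alignment_for_size(8) <= 8);
    }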

    #[test]
    fn test_avx512_optimization() {
        let optimizer = LayoutOptimizedAdam::avx512_optimized(1e-3, (0.9, 0.999), 1e-8, 0.01);
        let stats = optimizer.layout_stats();
        assert_eq!(stats.alignment_config.vector_size, 64);
    }
}