1use crate::common::{BiasCorrection, ParameterUpdate};
15use std::collections::HashMap;
16use trustformers_core::errors::{Result, TrustformersError};
17use trustformers_core::tensor::Tensor;
18use trustformers_core::traits::Optimizer;
19
/// Description of the CPU cache hierarchy, used to size the work blocks
/// the optimizer processes at a time.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// L1 data cache size in bytes.
    pub l1_cache_size: usize,
    /// L2 cache size in bytes.
    pub l2_cache_size: usize,
    /// L3 cache size in bytes.
    pub l3_cache_size: usize,
    /// Cache line size in bytes.
    pub cache_line_size: usize,
    /// Maximum block size, in elements, for blocked parameter updates.
    pub block_size: usize,
    /// Whether software prefetch hints should be issued.
    pub enable_prefetching: bool,
    /// Prefetch lookahead distance.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub prefetch_distance: usize,
}
38
39impl Default for CacheConfig {
40 fn default() -> Self {
41 Self {
42 l1_cache_size: 32 * 1024, l2_cache_size: 256 * 1024, l3_cache_size: 8 * 1024 * 1024, cache_line_size: 64, block_size: 1024, enable_prefetching: true,
48 prefetch_distance: 4,
49 }
50 }
51}
52
53impl CacheConfig {
54 pub fn detect_system() -> Self {
56 Self::default()
59 }
60
61 pub fn l1_optimized() -> Self {
63 Self {
64 block_size: 512, ..Default::default()
66 }
67 }
68
69 pub fn l2_optimized() -> Self {
71 Self {
72 block_size: 2048, ..Default::default()
74 }
75 }
76
77 pub fn l3_optimized() -> Self {
79 Self {
80 block_size: 8192, ..Default::default()
82 }
83 }
84
85 pub fn optimal_block_size_for_arrays(&self, num_arrays: usize) -> usize {
87 let available_cache = self.l2_cache_size / num_arrays;
89 let elements_per_cache = available_cache / std::mem::size_of::<f32>();
90
91 let mut block_size = 64;
93 while block_size * 2 <= elements_per_cache && block_size < 16384 {
94 block_size *= 2;
95 }
96
97 block_size.min(self.block_size)
98 }
99}
100
/// Optimizer state stored in a cache-friendly interleaved layout.
#[derive(Debug)]
pub struct CacheFriendlyState {
    /// Per-parameter buffers holding momentum and variance interleaved as
    /// `[m0, v0, m1, v1, ...]` (length = 2 * parameter size).
    pub interleaved_buffers: HashMap<usize, Vec<f32>>,
    /// Per-parameter bookkeeping (size, block size, last access step).
    pub param_metadata: HashMap<usize, ParameterMetadata>,
    /// Global step counter; advanced by `Optimizer::step()`.
    pub step: usize,
    /// Cache configuration used to choose block sizes.
    pub cache_config: CacheConfig,
}
117
/// Bookkeeping for one parameter's optimizer state.
#[derive(Debug, Clone)]
pub struct ParameterMetadata {
    /// Offset into the buffer; always 0 in the current implementation.
    pub offset: usize,
    /// Number of parameter elements.
    pub size: usize,
    /// Block size, in elements, chosen for this parameter's updates.
    pub block_size: usize,
    /// Step at which this parameter's state was last touched.
    pub last_access: usize,
}
130
131impl CacheFriendlyState {
132 pub fn new(cache_config: CacheConfig) -> Self {
134 Self {
135 interleaved_buffers: HashMap::new(),
136 param_metadata: HashMap::new(),
137 step: 0,
138 cache_config,
139 }
140 }
141
142 pub fn allocate_parameter(&mut self, param_id: usize, size: usize) -> Result<()> {
144 let buffer_size = size * 2; let buffer = vec![0.0; buffer_size];
148
149 let metadata = ParameterMetadata {
150 offset: 0,
151 size,
152 block_size: self.cache_config.optimal_block_size_for_arrays(3), last_access: self.step,
154 };
155
156 self.interleaved_buffers.insert(param_id, buffer);
157 self.param_metadata.insert(param_id, metadata);
158
159 Ok(())
160 }
161
162 pub fn get_interleaved_buffer_mut(&mut self, param_id: usize) -> Option<(&mut [f32], usize)> {
164 if let (Some(buffer), Some(metadata)) = (
165 self.interleaved_buffers.get_mut(¶m_id),
166 self.param_metadata.get_mut(¶m_id),
167 ) {
168 metadata.last_access = self.step;
169 Some((buffer.as_mut_slice(), metadata.size))
170 } else {
171 None
172 }
173 }
174
175 pub fn get_buffers_mut(&mut self, param_id: usize) -> Option<(Vec<f32>, Vec<f32>)> {
178 if let (Some(buffer), Some(metadata)) = (
179 self.interleaved_buffers.get(¶m_id),
180 self.param_metadata.get_mut(¶m_id),
181 ) {
182 metadata.last_access = self.step;
183
184 let mut momentum = Vec::with_capacity(metadata.size);
186 let mut variance = Vec::with_capacity(metadata.size);
187
188 for i in 0..metadata.size {
189 momentum.push(buffer[i * 2]);
190 variance.push(buffer[i * 2 + 1]);
191 }
192
193 Some((momentum, variance))
194 } else {
195 None
196 }
197 }
198
199 pub fn update_buffers(
201 &mut self,
202 param_id: usize,
203 momentum: &[f32],
204 variance: &[f32],
205 ) -> Result<()> {
206 if let Some(buffer) = self.interleaved_buffers.get_mut(¶m_id) {
207 if momentum.len() != variance.len() || momentum.len() * 2 != buffer.len() {
208 return Err(TrustformersError::tensor_op_error(
209 "Buffer size mismatch",
210 "update_buffers",
211 ));
212 }
213
214 for i in 0..momentum.len() {
216 buffer[i * 2] = momentum[i];
217 buffer[i * 2 + 1] = variance[i];
218 }
219
220 Ok(())
221 } else {
222 Err(TrustformersError::tensor_op_error(
223 "Parameter not found",
224 "update_buffers",
225 ))
226 }
227 }
228
229 pub fn garbage_collect(&mut self, access_threshold: usize) {
231 let current_step = self.step;
232 let stale_params: Vec<usize> = self
233 .param_metadata
234 .iter()
235 .filter(|(_, metadata)| current_step - metadata.last_access > access_threshold)
236 .map(|(id, _)| *id)
237 .collect();
238
239 for param_id in stale_params {
240 self.interleaved_buffers.remove(¶m_id);
241 self.param_metadata.remove(¶m_id);
242 }
243 }
244}
245
/// Adam optimizer whose per-parameter state is kept in an interleaved,
/// cache-aware layout (see [`CacheFriendlyState`]).
#[derive(Debug)]
pub struct CacheFriendlyAdam {
    /// Learning rate.
    lr: f32,
    /// Exponential decay rates (beta1, beta2) for the moment estimates.
    betas: (f32, f32),
    /// Denominator epsilon for numerical stability.
    eps: f32,
    /// Weight-decay coefficient, folded into the gradient during updates.
    weight_decay: f32,
    /// Per-parameter optimizer state and cache configuration.
    state: CacheFriendlyState,
}
263
impl CacheFriendlyAdam {
    /// Creates an optimizer with the default cache configuration.
    pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
        Self::with_cache_config(lr, betas, eps, weight_decay, CacheConfig::default())
    }

    /// Creates an optimizer with an explicit cache configuration.
    pub fn with_cache_config(
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        cache_config: CacheConfig,
    ) -> Self {
        Self {
            lr,
            betas,
            eps,
            weight_decay,
            state: CacheFriendlyState::new(cache_config),
        }
    }

    /// String-keyed wrapper around `update_parameter_blocked_fast`.
    ///
    /// NOTE(review): the numeric id is the *address* of the string's data,
    /// not a hash of its contents — two equal strings yield different ids,
    /// and a freed/reused allocation could alias another parameter's state.
    /// Confirm callers pass a stable, long-lived `param_id` before relying
    /// on this (currently dead code).
    #[allow(dead_code)]
    fn update_parameter_blocked(
        &mut self,
        param: &mut [f32],
        grad: &[f32],
        param_id: String,
    ) -> Result<()> {
        let numeric_id = param_id.as_ptr() as usize;
        self.update_parameter_blocked_fast(param, grad, numeric_id)
    }

    /// Core in-place Adam update over `param`/`grad`, keyed by `param_id`.
    ///
    /// Momentum/variance live interleaved per element; tensors under 4096
    /// elements take a single linear pass, larger ones are walked in
    /// cache-sized blocks.
    ///
    /// # Errors
    /// Fails if `param` and `grad` differ in length, or state buffers
    /// cannot be fetched after allocation.
    fn update_parameter_blocked_fast(
        &mut self,
        param: &mut [f32],
        grad: &[f32],
        param_id: usize,
    ) -> Result<()> {
        let size = param.len();
        if grad.len() != size {
            return Err(TrustformersError::tensor_op_error(
                "Parameter and gradient size mismatch",
                "update_parameter_blocked_fast",
            ));
        }

        // Lazily allocate state; re-allocate (which zeroes the moments) if
        // the parameter's size changed since it was last seen.
        if !self.state.param_metadata.contains_key(&param_id) {
            self.state.allocate_parameter(param_id, size)?;
        } else {
            let current_size =
                self.state.param_metadata.get(&param_id).map(|meta| meta.size).unwrap_or(0);
            if current_size != size {
                self.state.allocate_parameter(param_id, size)?;
            }
        }

        // Bias correction uses step + 1; `state.step` itself is only
        // advanced by `Optimizer::step()`, which callers invoke separately.
        let step = self.state.step + 1;
        let block_size = self
            .state
            .param_metadata
            .get(&param_id)
            .map(|meta| meta.block_size)
            .unwrap_or(1024);
        let _enable_prefetching = self.state.cache_config.enable_prefetching;

        let (bias_correction1, bias_correction2) =
            BiasCorrection::compute_adam_corrections(self.betas.0, self.betas.1, step);

        let (interleaved_buffer, _param_size) =
            self.state.get_interleaved_buffer_mut(param_id).ok_or_else(|| {
                TrustformersError::tensor_op_error(
                    "Failed to get parameter buffers",
                    "update_parameter_blocked_fast",
                )
            })?;

        if size < 4096 {
            // Small tensors: one linear pass; blocking would add overhead.
            for i in 0..size {
                // Weight decay folded into the gradient (L2-style).
                let grad_val = grad[i] + self.weight_decay * param[i];

                // Interleaved layout: momentum at 2i, variance at 2i + 1.
                let momentum_idx = i * 2;
                let variance_idx = i * 2 + 1;

                // Exponential moving averages of grad and grad^2.
                interleaved_buffer[momentum_idx] = self.betas.0 * interleaved_buffer[momentum_idx]
                    + (1.0 - self.betas.0) * grad_val;
                interleaved_buffer[variance_idx] = self.betas.1 * interleaved_buffer[variance_idx]
                    + (1.0 - self.betas.1) * grad_val * grad_val;

                let m_hat = interleaved_buffer[momentum_idx] / bias_correction1;
                let v_hat = interleaved_buffer[variance_idx] / bias_correction2;

                param[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
            }
        } else {
            // Large tensors: process cache-sized blocks so the param, grad,
            // and state working set stays resident while it is updated.
            let num_blocks = size.div_ceil(block_size);

            for block_idx in 0..num_blocks {
                let start = block_idx * block_size;
                let end = (start + block_size).min(size);

                for i in start..end {
                    // Same per-element math as the small-tensor path.
                    let grad_val = grad[i] + self.weight_decay * param[i];

                    let momentum_idx = i * 2;
                    let variance_idx = i * 2 + 1;

                    interleaved_buffer[momentum_idx] = self.betas.0
                        * interleaved_buffer[momentum_idx]
                        + (1.0 - self.betas.0) * grad_val;
                    interleaved_buffer[variance_idx] = self.betas.1
                        * interleaved_buffer[variance_idx]
                        + (1.0 - self.betas.1) * grad_val * grad_val;

                    let m_hat = interleaved_buffer[momentum_idx] / bias_correction1;
                    let v_hat = interleaved_buffer[variance_idx] / bias_correction2;

                    param[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
                }
            }
        }

        Ok(())
    }

    /// Fused Adam step over pre-split blocks (momentum/variance supplied
    /// as separate slices). Currently unused.
    #[inline]
    #[allow(dead_code)]
    fn process_block_fused(
        &self,
        param_block: &mut [f32],
        grad_block: &[f32],
        momentum_block: &mut [f32],
        variance_block: &mut [f32],
        bias_correction1: f32,
        bias_correction2: f32,
    ) {
        for i in 0..param_block.len() {
            let grad_val = grad_block[i] + self.weight_decay * param_block[i];

            ParameterUpdate::update_ema(&mut momentum_block[i], grad_val, self.betas.0);
            ParameterUpdate::update_ema(&mut variance_block[i], grad_val * grad_val, self.betas.1);

            let m_hat = momentum_block[i] / bias_correction1;
            let v_hat = variance_block[i] / bias_correction2;

            ParameterUpdate::adam_update(&mut param_block[i], self.lr, m_hat, v_hat, self.eps);
        }
    }

    /// Issues best-effort hardware prefetch hints for `block`.
    /// Currently unused.
    #[inline]
    #[allow(dead_code)]
    fn prefetch_block(&self, block: &[f32]) {
        if block.is_empty() {
            return;
        }

        let ptr = block.as_ptr();

        #[cfg(target_arch = "x86_64")]
        {
            // SAFETY: _mm_prefetch is a non-faulting hint, and `ptr`
            // points into a live, non-empty slice here.
            unsafe {
                std::arch::x86_64::_mm_prefetch(ptr as *const i8, std::arch::x86_64::_MM_HINT_T0);

                if block.len() > 16 {
                    // Also hint the middle of larger blocks.
                    let mid_ptr = ptr.wrapping_add(block.len() / 2);
                    std::arch::x86_64::_mm_prefetch(
                        mid_ptr as *const i8,
                        std::arch::x86_64::_MM_HINT_T0,
                    );
                }
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            // SAFETY: PRFM is a non-faulting prefetch hint; `ptr` is valid.
            unsafe {
                std::arch::asm!(
                    "prfm pldl1keep, [{}]",
                    in(reg) ptr,
                    options(nostack, preserves_flags)
                );
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            // Fallback: touch the first cache line with a volatile read.
            // SAFETY: `ptr` points to the first element of a non-empty slice.
            let _ = unsafe { std::ptr::read_volatile(ptr) };
        }
    }

    /// Summarizes state-buffer memory usage and estimated cache pressure.
    pub fn cache_stats(&self) -> CacheStats {
        let buffer_memory: usize = self
            .state
            .interleaved_buffers
            .values()
            .map(|buffer| buffer.len() * std::mem::size_of::<f32>())
            .sum();

        let num_params = self.state.param_metadata.len();
        let total_elements: usize = self.state.param_metadata.values().map(|meta| meta.size).sum();

        CacheStats {
            buffer_memory_bytes: buffer_memory,
            num_parameters: num_params,
            total_elements,
            cache_config: self.state.cache_config.clone(),
            estimated_l1_utilization: self
                .estimate_cache_utilization(buffer_memory, self.state.cache_config.l1_cache_size),
            estimated_l2_utilization: self
                .estimate_cache_utilization(buffer_memory, self.state.cache_config.l2_cache_size),
        }
    }

    /// Ratio of working set to cache size, clamped to [0, 1]; a zero-sized
    /// cache reports full (1.0) utilization.
    fn estimate_cache_utilization(&self, working_set_size: usize, cache_size: usize) -> f32 {
        if cache_size == 0 {
            return 1.0;
        }
        (working_set_size as f32 / cache_size as f32).min(1.0)
    }

    /// Drops optimizer state for parameters not updated within
    /// `steps_threshold` steps.
    pub fn cleanup_unused_params(&mut self, steps_threshold: usize) {
        self.state.garbage_collect(steps_threshold);
    }
}
526
527impl Optimizer for CacheFriendlyAdam {
528 fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
529 match (parameter, grad) {
530 (Tensor::F32(param), Tensor::F32(grad_arr)) => {
531 let param_id = param.as_ptr() as usize;
533 let param_slice = param.as_slice_mut().ok_or_else(|| {
534 TrustformersError::tensor_op_error(
535 "Parameter tensor is not contiguous",
536 "update",
537 )
538 })?;
539 let grad_slice = grad_arr.as_slice().ok_or_else(|| {
540 TrustformersError::tensor_op_error(
541 "Gradient tensor is not contiguous",
542 "update",
543 )
544 })?;
545 self.update_parameter_blocked_fast(param_slice, grad_slice, param_id)
546 },
547 _ => Err(TrustformersError::tensor_op_error(
548 "Unsupported tensor types for CacheFriendlyAdam",
549 "update",
550 )),
551 }
552 }
553
554 fn zero_grad(&mut self) {
555 }
557
558 fn step(&mut self) {
559 self.state.step += 1;
560 }
561
562 fn get_lr(&self) -> f32 {
563 self.lr
564 }
565
566 fn set_lr(&mut self, lr: f32) {
567 self.lr = lr;
568 }
569}
570
/// Snapshot of optimizer state-memory usage and cache-pressure estimates.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Total bytes held by interleaved state buffers.
    pub buffer_memory_bytes: usize,
    /// Number of tracked parameters.
    pub num_parameters: usize,
    /// Total parameter elements across all tracked parameters.
    pub total_elements: usize,
    /// Cache configuration the stats were computed against.
    pub cache_config: CacheConfig,
    /// Working set / L1 size, clamped to [0, 1].
    pub estimated_l1_utilization: f32,
    /// Working set / L2 size, clamped to [0, 1].
    pub estimated_l2_utilization: f32,
}
587
588impl CacheStats {
589 pub fn optimization_suggestions(&self) -> Vec<String> {
591 let mut suggestions = Vec::new();
592
593 if self.estimated_l1_utilization > 0.8 {
594 suggestions.push("Consider reducing block size for better L1 cache fit".to_string());
595 }
596
597 if self.estimated_l2_utilization > 0.9 {
598 suggestions
599 .push("Working set exceeds L2 cache; consider parameter partitioning".to_string());
600 }
601
602 if self.cache_config.block_size > 8192 {
603 suggestions.push("Large block size may cause cache thrashing".to_string());
604 }
605
606 if !self.cache_config.enable_prefetching {
607 suggestions.push("Enable prefetching for potential performance gains".to_string());
608 }
609
610 if suggestions.is_empty() {
611 suggestions.push("Cache utilization appears optimal".to_string());
612 }
613
614 suggestions
615 }
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults should reflect desktop-class cache sizes; presets override
    // only the block size.
    #[test]
    fn test_cache_config_creation() {
        let config = CacheConfig::default();
        assert_eq!(config.l1_cache_size, 32 * 1024);
        assert_eq!(config.cache_line_size, 64);
        assert!(config.enable_prefetching);

        let l1_config = CacheConfig::l1_optimized();
        assert_eq!(l1_config.block_size, 512);
    }

    // Chosen block size must be positive, capped by the configured
    // maximum, and a power of two.
    #[test]
    fn test_optimal_block_size() {
        let config = CacheConfig::default();
        let block_size = config.optimal_block_size_for_arrays(3);
        assert!(block_size > 0);
        assert!(block_size <= config.block_size);
        assert_eq!(block_size & (block_size - 1), 0); // power of two
    }

    // Allocation registers both a buffer and metadata; the de-interleaved
    // views have the logical parameter length.
    #[test]
    fn test_cache_friendly_state() {
        let mut state = CacheFriendlyState::new(CacheConfig::default());

        let param_id = 12345usize;
        state.allocate_parameter(param_id, 100).unwrap();

        assert!(state.param_metadata.contains_key(&param_id));
        assert!(state.interleaved_buffers.contains_key(&param_id));

        let (momentum, variance) = state.get_buffers_mut(param_id).unwrap();
        assert_eq!(momentum.len(), 100);
        assert_eq!(variance.len(), 100);
    }

    // Constructor should store hyperparameters unchanged.
    #[test]
    fn test_cache_friendly_adam() {
        let optimizer = CacheFriendlyAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.betas, (0.9, 0.999));
        assert_eq!(optimizer.eps, 1e-8);
        assert_eq!(optimizer.weight_decay, 0.01);
    }

    // A fresh optimizer tracks nothing but still produces at least one
    // suggestion string.
    #[test]
    fn test_cache_stats() {
        let optimizer = CacheFriendlyAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        let stats = optimizer.cache_stats();

        assert_eq!(stats.num_parameters, 0);
        assert_eq!(stats.total_elements, 0);
        assert_eq!(stats.buffer_memory_bytes, 0);

        let suggestions = stats.optimization_suggestions();
        assert!(!suggestions.is_empty());
    }

    // Utilization is working-set / cache-size, saturating at 1.0.
    #[test]
    fn test_cache_utilization_estimation() {
        let optimizer = CacheFriendlyAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);

        let utilization = optimizer.estimate_cache_utilization(16 * 1024, 32 * 1024);
        assert_eq!(utilization, 0.5);

        let over_utilization = optimizer.estimate_cache_utilization(64 * 1024, 32 * 1024);
        assert_eq!(over_utilization, 1.0);
    }

    // GC keeps recently-touched parameters and drops stale ones.
    #[test]
    fn test_garbage_collection() {
        let mut state = CacheFriendlyState::new(CacheConfig::default());

        let param1_id = 11111usize;
        let param2_id = 22222usize;
        state.allocate_parameter(param1_id, 100).unwrap();
        state.allocate_parameter(param2_id, 200).unwrap();

        // Simulate many steps passing since allocation.
        state.step = 1000;

        // Touching param1 refreshes its last_access stamp.
        state.get_buffers_mut(param1_id);

        state.garbage_collect(10);

        assert!(state.param_metadata.contains_key(&param1_id));
        assert!(!state.param_metadata.contains_key(&param2_id));
    }
}