1use crate::common::{BiasCorrection, ParameterUpdate, StateMemoryStats};
15use scirs2_core::parallel_ops::*; use std::collections::HashMap;
17use std::sync::{Arc, Mutex, RwLock};
18use trustformers_core::errors::{Result, TrustformersError};
19use trustformers_core::tensor::Tensor;
20use trustformers_core::traits::Optimizer;
21
/// Configuration controlling how optimizer updates are parallelized.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Number of worker threads; 0 means "auto-detect all available CPUs".
    pub num_threads: usize,
    /// Minimum number of parameter elements per thread before parallelizing.
    pub min_params_per_thread: usize,
    /// Whether the scheduler may steal work between threads.
    pub enable_work_stealing: bool,
    /// Hint that scheduling/allocation should be NUMA-aware.
    pub numa_aware: bool,
    /// Number of elements processed per parallel chunk.
    pub chunk_size: usize,
    /// Prefer lock-free state access where possible.
    pub lock_free: bool,
}
38
39impl Default for ParallelConfig {
40 fn default() -> Self {
41 Self {
42 num_threads: 0, min_params_per_thread: 1000,
44 enable_work_stealing: true,
45 numa_aware: false,
46 chunk_size: 1024,
47 lock_free: true,
48 }
49 }
50}
51
52impl ParallelConfig {
53 pub fn cpu_optimized() -> Self {
55 Self {
56 num_threads: num_cpus::get(),
57 chunk_size: 512,
58 enable_work_stealing: true,
59 ..Default::default()
60 }
61 }
62
63 pub fn large_model() -> Self {
65 Self {
66 num_threads: num_cpus::get(),
67 min_params_per_thread: 10000,
68 chunk_size: 4096,
69 numa_aware: true,
70 ..Default::default()
71 }
72 }
73
74 pub fn memory_bound() -> Self {
76 Self {
77 num_threads: (num_cpus::get() / 2).max(1),
78 chunk_size: 2048,
79 numa_aware: true,
80 ..Default::default()
81 }
82 }
83
84 pub fn effective_num_threads(&self) -> usize {
86 if self.num_threads == 0 {
87 num_cpus::get()
88 } else {
89 self.num_threads
90 }
91 }
92}
93
/// Thread-safe container for per-parameter optimizer state shared across
/// worker threads.
#[derive(Debug)]
pub struct ParallelOptimizerState {
    // Outer RwLock guards the map; each entry carries its own Mutex so
    // distinct parameters can be updated concurrently.
    parameter_states: RwLock<HashMap<String, Arc<Mutex<ParameterState>>>>,
    // Global step counter, incremented atomically.
    global_step: Arc<std::sync::atomic::AtomicUsize>,
    // Parallelization settings for this optimizer instance.
    config: ParallelConfig,
}
104
/// Per-parameter Adam state: first/second moment buffers plus step and
/// last-touch bookkeeping.
#[derive(Debug)]
pub struct ParameterState {
    pub momentum: Vec<f32>,
    pub variance: Vec<f32>,
    pub step: usize,
    pub last_update: std::time::Instant,
}

impl ParameterState {
    /// Builds a fresh state with zeroed moment buffers of `size` elements
    /// and the update clock started now.
    fn new(size: usize) -> Self {
        let created = std::time::Instant::now();
        Self {
            momentum: vec![0.0; size],
            variance: vec![0.0; size],
            step: 0,
            last_update: created,
        }
    }
}
124
125impl ParallelOptimizerState {
126 pub fn new(config: ParallelConfig) -> Self {
128 Self {
129 parameter_states: RwLock::new(HashMap::new()),
130 global_step: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
131 config,
132 }
133 }
134
135 pub fn get_or_create_state(&self, param_id: String, size: usize) -> Arc<Mutex<ParameterState>> {
137 {
139 let states = self.parameter_states.read().unwrap();
140 if let Some(state) = states.get(¶m_id) {
141 return state.clone();
142 }
143 }
144
145 let mut states = self.parameter_states.write().unwrap();
147 if let Some(state) = states.get(¶m_id) {
149 return state.clone();
150 }
151
152 let new_state = Arc::new(Mutex::new(ParameterState::new(size)));
153 states.insert(param_id, new_state.clone());
154 new_state
155 }
156
157 pub fn step(&self) {
159 self.global_step.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
160 }
161
162 pub fn get_step(&self) -> usize {
164 self.global_step.load(std::sync::atomic::Ordering::Relaxed)
165 }
166
167 pub fn memory_usage(&self) -> StateMemoryStats {
169 let states = self.parameter_states.read().unwrap();
170 let mut total_momentum = 0;
171 let mut total_variance = 0;
172 let num_params = states.len();
173
174 for state_arc in states.values() {
175 if let Ok(state) = state_arc.try_lock() {
176 total_momentum += state.momentum.len();
177 total_variance += state.variance.len();
178 }
179 }
180
181 StateMemoryStats {
182 momentum_elements: total_momentum,
183 variance_elements: total_variance,
184 third_moment_elements: 0,
185 total_bytes: (total_momentum + total_variance) * std::mem::size_of::<f32>(),
186 num_parameters: num_params,
187 }
188 }
189
190 pub fn clear(&self) {
192 let mut states = self.parameter_states.write().unwrap();
193 states.clear();
194 self.global_step.store(0, std::sync::atomic::Ordering::Relaxed);
195 }
196}
197
/// Adam optimizer that distributes parameter updates across threads.
#[derive(Debug)]
pub struct ParallelAdam {
    // Learning rate.
    lr: f32,
    // Exponential decay rates for the first and second moment estimates.
    betas: (f32, f32),
    // Small denominator constant for numerical stability.
    eps: f32,
    // L2 weight-decay coefficient (folded into the gradient; see process_chunk).
    weight_decay: f32,
    // Shared, thread-safe per-parameter optimizer state.
    state: ParallelOptimizerState,
}
212
213impl ParallelAdam {
214 pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
216 Self::with_config(lr, betas, eps, weight_decay, ParallelConfig::default())
217 }
218
219 pub fn with_config(
221 lr: f32,
222 betas: (f32, f32),
223 eps: f32,
224 weight_decay: f32,
225 config: ParallelConfig,
226 ) -> Self {
227 Self {
228 lr,
229 betas,
230 eps,
231 weight_decay,
232 state: ParallelOptimizerState::new(config),
233 }
234 }
235
236 pub fn update_parallel(&self, updates: Vec<(String, &mut [f32], &[f32])>) -> Result<()> {
238 let _chunk_size = self.state.config.chunk_size;
239 let min_params = self.state.config.min_params_per_thread;
240
241 if updates.len() < min_params || !self.should_parallelize(&updates) {
242 return self.update_sequential(updates);
244 }
245
246 let results: Result<Vec<()>> = updates
248 .into_par_iter()
249 .with_min_len(1)
250 .map(|(param_id, param, grad)| self.update_single_parameter(param_id, param, grad))
251 .collect();
252
253 results.map(|_| ())
254 }
255
256 fn update_sequential(&self, updates: Vec<(String, &mut [f32], &[f32])>) -> Result<()> {
258 for (param_id, param, grad) in updates {
259 self.update_single_parameter(param_id, param, grad)?;
260 }
261 Ok(())
262 }
263
264 fn update_single_parameter(
266 &self,
267 param_id: String,
268 param: &mut [f32],
269 grad: &[f32],
270 ) -> Result<()> {
271 if param.len() != grad.len() {
272 return Err(TrustformersError::tensor_op_error(
273 "Parameter and gradient size mismatch",
274 "update_single_parameter",
275 ));
276 }
277
278 let size = param.len();
279 let state_arc = self.state.get_or_create_state(param_id, size);
280 let chunk_size = self.state.config.chunk_size;
281
282 let mut param_state = state_arc.lock().expect("Parallel optimizer state lock poisoned");
284 param_state.step += 1;
285 param_state.last_update = std::time::Instant::now();
286
287 let step = param_state.step;
288 let (bias_correction1, bias_correction2) =
289 BiasCorrection::compute_adam_corrections(self.betas.0, self.betas.1, step);
290
291 let should_parallelize = size >= chunk_size * 2 && self.state.config.num_threads > 1;
293 if should_parallelize {
294 let ParameterState {
296 ref mut momentum,
297 ref mut variance,
298 ..
299 } = *param_state;
300 self.update_parameter_parallel(
301 param,
302 grad,
303 momentum,
304 variance,
305 bias_correction1,
306 bias_correction2,
307 chunk_size,
308 );
309 } else {
310 let ParameterState {
312 ref mut momentum,
313 ref mut variance,
314 ..
315 } = *param_state;
316 self.update_parameter_sequential(
317 param,
318 grad,
319 momentum,
320 variance,
321 bias_correction1,
322 bias_correction2,
323 );
324 }
325
326 Ok(())
327 }
328
329 fn update_parameter_parallel(
331 &self,
332 param: &mut [f32],
333 grad: &[f32],
334 momentum: &mut [f32],
335 variance: &mut [f32],
336 bias_correction1: f32,
337 bias_correction2: f32,
338 chunk_size: usize,
339 ) {
340 param
342 .par_chunks_mut(chunk_size)
343 .zip(grad.par_chunks(chunk_size))
344 .zip(momentum.par_chunks_mut(chunk_size))
345 .zip(variance.par_chunks_mut(chunk_size))
346 .for_each(|(((p_chunk, g_chunk), m_chunk), v_chunk)| {
347 self.process_chunk(
348 p_chunk,
349 g_chunk,
350 m_chunk,
351 v_chunk,
352 bias_correction1,
353 bias_correction2,
354 );
355 });
356 }
357
358 fn update_parameter_sequential(
360 &self,
361 param: &mut [f32],
362 grad: &[f32],
363 momentum: &mut [f32],
364 variance: &mut [f32],
365 bias_correction1: f32,
366 bias_correction2: f32,
367 ) {
368 self.process_chunk(
369 param,
370 grad,
371 momentum,
372 variance,
373 bias_correction1,
374 bias_correction2,
375 );
376 }
377
378 #[inline]
380 fn process_chunk(
381 &self,
382 param_chunk: &mut [f32],
383 grad_chunk: &[f32],
384 momentum_chunk: &mut [f32],
385 variance_chunk: &mut [f32],
386 bias_correction1: f32,
387 bias_correction2: f32,
388 ) {
389 let len = param_chunk
391 .len()
392 .min(grad_chunk.len())
393 .min(momentum_chunk.len())
394 .min(variance_chunk.len());
395
396 for i in 0..len {
397 let grad_val = grad_chunk[i] + self.weight_decay * param_chunk[i];
398
399 ParameterUpdate::update_ema(&mut momentum_chunk[i], grad_val, self.betas.0);
401 ParameterUpdate::update_ema(&mut variance_chunk[i], grad_val * grad_val, self.betas.1);
402
403 let m_hat = momentum_chunk[i] / bias_correction1;
405 let v_hat = variance_chunk[i] / bias_correction2;
406
407 ParameterUpdate::adam_update(&mut param_chunk[i], self.lr, m_hat, v_hat, self.eps);
408 }
409 }
410
411 fn should_parallelize(&self, updates: &[(String, &mut [f32], &[f32])]) -> bool {
413 let total_elements: usize = updates.iter().map(|(_, param, _)| param.len()).sum();
414 let num_threads = self.state.config.effective_num_threads();
415
416 total_elements >= self.state.config.min_params_per_thread * num_threads
417 }
418
419 pub fn parallel_stats(&self) -> ParallelStats {
421 let memory_stats = self.state.memory_usage();
422 let num_threads = self.state.config.effective_num_threads();
423
424 ParallelStats {
425 num_threads,
426 memory_stats,
427 config: self.state.config.clone(),
428 current_step: self.state.get_step(),
429 }
430 }
431
432 pub fn configure_thread_pool(&self) -> Result<()> {
434 let num_threads = self.state.config.effective_num_threads();
435
436 ThreadPoolBuilder::new().num_threads(num_threads).build_global().map_err(|e| {
437 TrustformersError::tensor_op_error(
438 &format!("Failed to configure thread pool: {}", e),
439 "configure_thread_pool",
440 )
441 })?;
442
443 Ok(())
444 }
445}
446
impl Optimizer for ParallelAdam {
    /// Applies one Adam step to `parameter` in place using `grad`.
    ///
    /// Only `Tensor::F32` pairs are supported; any other combination
    /// returns a tensor-op error.
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        match (parameter, grad) {
            (Tensor::F32(param), Tensor::F32(grad_arr)) => {
                // Parameter identity is the data pointer formatted as a
                // string. NOTE(review): if a tensor's buffer is ever
                // reallocated, or an address is reused by another tensor,
                // optimizer state would be lost or mixed up — confirm
                // buffers are stable across steps.
                let param_id = format!("{:p}", param.as_ptr());
                self.update_single_parameter(
                    param_id,
                    param.as_slice_mut().unwrap(),
                    grad_arr.as_slice().unwrap(),
                )
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor types for ParallelAdam",
                "update",
            )),
        }
    }

    /// Intentional no-op: gradients are supplied by the caller rather
    /// than accumulated inside the optimizer, so there is nothing to clear.
    fn zero_grad(&mut self) {
    }

    /// Advances the shared global step counter.
    fn step(&mut self) {
        self.state.step();
    }

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32 {
        self.lr
    }

    /// Replaces the learning rate (e.g. for schedulers).
    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}
481
/// Snapshot of the optimizer's parallel execution characteristics.
#[derive(Debug, Clone)]
pub struct ParallelStats {
    /// Resolved worker thread count (after mapping 0 to the CPU count).
    pub num_threads: usize,
    /// Aggregated optimizer-state memory usage at snapshot time.
    pub memory_stats: StateMemoryStats,
    /// Copy of the parallel configuration in effect.
    pub config: ParallelConfig,
    /// Global step at snapshot time.
    pub current_step: usize,
}
494
495impl ParallelStats {
496 pub fn theoretical_speedup(&self, _sequential_time_ms: f64) -> f64 {
498 let parallel_fraction = 0.95; let num_threads = self.num_threads as f64;
501
502 1.0 / ((1.0 - parallel_fraction) + (parallel_fraction / num_threads))
503 }
504
505 pub fn optimization_suggestions(&self) -> Vec<String> {
507 let mut suggestions = Vec::new();
508
509 if self.num_threads == 1 {
510 suggestions.push(
511 "Consider increasing number of threads for better parallelization".to_string(),
512 );
513 }
514
515 if self.num_threads > num_cpus::get() {
516 suggestions.push("Number of threads exceeds CPU cores; consider reducing".to_string());
517 }
518
519 if self.config.chunk_size < 256 {
520 suggestions
521 .push("Small chunk size may cause overhead; consider increasing".to_string());
522 }
523
524 if self.config.chunk_size > 8192 {
525 suggestions.push("Large chunk size may reduce parallelization efficiency".to_string());
526 }
527
528 if !self.config.enable_work_stealing {
529 suggestions.push("Enable work stealing for better load balancing".to_string());
530 }
531
532 if suggestions.is_empty() {
533 suggestions.push("Parallel configuration appears optimal".to_string());
534 }
535
536 suggestions
537 }
538}
539
/// Optimizers that can consume a whole batch of (parameter, gradient)
/// tensor pairs in one call.
pub trait BatchUpdate {
    /// Updates every parameter in `batch` from its paired gradient.
    fn update_batch(&mut self, batch: Vec<(&mut Tensor, &Tensor)>) -> Result<()>;
}
545
546impl BatchUpdate for ParallelAdam {
547 fn update_batch(&mut self, batch: Vec<(&mut Tensor, &Tensor)>) -> Result<()> {
548 let mut updates = Vec::new();
549
550 for (param, grad) in batch {
551 match (param, grad) {
552 (Tensor::F32(p), Tensor::F32(g)) => {
553 let param_id = format!("{:p}", p.as_ptr());
554 updates.push((param_id, p.as_slice_mut().unwrap(), g.as_slice().unwrap()));
555 },
556 _ => {
557 return Err(TrustformersError::tensor_op_error(
558 "Unsupported tensor types",
559 "update_batch",
560 ))
561 },
562 }
563 }
564
565 self.update_parallel(updates)
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    // Default config keeps 0 as the auto-detect sentinel; presets and
    // effective_num_threads() resolve it to the CPU count.
    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::default();
        assert_eq!(config.num_threads, 0); assert!(config.enable_work_stealing);

        let cpu_config = ParallelConfig::cpu_optimized();
        assert_eq!(cpu_config.num_threads, num_cpus::get());

        let effective_threads = config.effective_num_threads();
        assert!(effective_threads > 0);
        assert_eq!(effective_threads, num_cpus::get());
    }

    // Global step counting and lazy per-parameter state creation.
    #[test]
    fn test_parallel_optimizer_state() {
        let config = ParallelConfig::default();
        let state = ParallelOptimizerState::new(config);

        assert_eq!(state.get_step(), 0);
        state.step();
        assert_eq!(state.get_step(), 1);

        // First access creates zeroed moment buffers of the requested size.
        let param_state = state.get_or_create_state("test_param".to_string(), 100);
        let locked_state = param_state.lock().expect("Parallel optimizer state lock poisoned");
        assert_eq!(locked_state.momentum.len(), 100);
        assert_eq!(locked_state.variance.len(), 100);
    }

    // Constructor wiring: hyperparameters stored, stats start at step 0.
    #[test]
    fn test_parallel_adam() {
        let optimizer = ParallelAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.betas, (0.9, 0.999));

        let stats = optimizer.parallel_stats();
        assert!(stats.num_threads > 0);
        assert_eq!(stats.current_step, 0);
    }

    // Parallel dispatch requires total elements >= min_params_per_thread
    // * num_threads (here 1000 * 4 = 4000).
    #[test]
    fn test_should_parallelize() {
        let config = ParallelConfig {
            min_params_per_thread: 1000,
            num_threads: 4,
            ..Default::default()
        };
        let optimizer = ParallelAdam::with_config(1e-3, (0.9, 0.999), 1e-8, 0.01, config);

        // 100 elements < 4000 threshold -> stays sequential.
        let mut small_params = [0.0; 100];
        let small_grads = [0.0; 100];
        let small_updates = vec![(
            "param1".to_string(),
            &mut small_params[..],
            &small_grads[..],
        )];
        assert!(!optimizer.should_parallelize(&small_updates));

        // 5000 elements >= 4000 threshold -> parallelizes.
        let mut large_params = [0.0; 5000];
        let large_grads = [0.0; 5000];
        let large_updates = vec![(
            "param1".to_string(),
            &mut large_params[..],
            &large_grads[..],
        )];
        assert!(optimizer.should_parallelize(&large_updates));
    }

    // Amdahl speedup is bounded above by the thread count; suggestions
    // always contain at least one entry.
    // NOTE(review): `speedup > 1.0` fails on a single-core machine where
    // num_threads resolves to 1 (speedup is exactly 1.0) — confirm CI
    // always runs multi-core.
    #[test]
    fn test_parallel_stats() {
        let optimizer = ParallelAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        let stats = optimizer.parallel_stats();

        let speedup = stats.theoretical_speedup(1000.0);
        assert!(speedup > 1.0);
        assert!(speedup <= stats.num_threads as f64);

        let suggestions = stats.optimization_suggestions();
        assert!(!suggestions.is_empty());
    }

    // Memory accounting sums momentum/variance elements across parameters.
    #[test]
    fn test_memory_usage() {
        let config = ParallelConfig::default();
        let state = ParallelOptimizerState::new(config);

        state.get_or_create_state("param1".to_string(), 1000);
        state.get_or_create_state("param2".to_string(), 2000);

        let memory_stats = state.memory_usage();
        assert_eq!(memory_stats.num_parameters, 2);
        assert_eq!(memory_stats.momentum_elements, 3000);
        assert_eq!(memory_stats.variance_elements, 3000);
    }
}