1use crate::common::{BiasCorrection, ParameterUpdate, StateMemoryStats};
15use scirs2_core::parallel_ops::*; use std::collections::HashMap;
17use std::sync::{Arc, Mutex, RwLock};
18use trustformers_core::errors::{Result, TrustformersError};
19use trustformers_core::tensor::Tensor;
20use trustformers_core::traits::Optimizer;
21
/// Tuning knobs for parallel optimizer execution.
///
/// Build one with [`Default::default`] or one of the preset constructors and
/// override individual fields with struct-update syntax.
#[derive(Clone, Debug)]
pub struct ParallelConfig {
    /// Requested number of worker threads. `0` means "auto":
    /// `effective_num_threads` substitutes `num_cpus::get()` in that case.
    pub num_threads: usize,
    /// Minimum amount of work per thread before the optimizer considers a
    /// batch worth running in parallel.
    pub min_params_per_thread: usize,
    /// Work-stealing preference. In the code visible in this file it is only
    /// inspected when generating tuning suggestions.
    pub enable_work_stealing: bool,
    /// NUMA-awareness flag. Not read by any code visible in this file.
    pub numa_aware: bool,
    /// Number of elements per chunk when a single parameter is split up for
    /// parallel processing.
    pub chunk_size: usize,
    /// Lock-free preference flag. Not read by any code visible in this file.
    pub lock_free: bool,
}

impl Default for ParallelConfig {
    /// Baseline configuration: automatic thread count, a 1000-parameter
    /// per-thread threshold, 1024-element chunks, work stealing and the
    /// lock-free flag enabled, NUMA awareness disabled.
    fn default() -> Self {
        ParallelConfig {
            // 0 = pick the thread count automatically.
            num_threads: 0,
            min_params_per_thread: 1000,
            chunk_size: 1024,
            enable_work_stealing: true,
            lock_free: true,
            numa_aware: false,
        }
    }
}
51
52impl ParallelConfig {
53 pub fn cpu_optimized() -> Self {
55 Self {
56 num_threads: num_cpus::get(),
57 chunk_size: 512,
58 enable_work_stealing: true,
59 ..Default::default()
60 }
61 }
62
63 pub fn large_model() -> Self {
65 Self {
66 num_threads: num_cpus::get(),
67 min_params_per_thread: 10000,
68 chunk_size: 4096,
69 numa_aware: true,
70 ..Default::default()
71 }
72 }
73
74 pub fn memory_bound() -> Self {
76 Self {
77 num_threads: (num_cpus::get() / 2).max(1),
78 chunk_size: 2048,
79 numa_aware: true,
80 ..Default::default()
81 }
82 }
83
84 pub fn effective_num_threads(&self) -> usize {
86 if self.num_threads == 0 {
87 num_cpus::get()
88 } else {
89 self.num_threads
90 }
91 }
92}
93
94#[derive(Debug)]
96pub struct ParallelOptimizerState {
97 parameter_states: RwLock<HashMap<String, Arc<Mutex<ParameterState>>>>,
99 global_step: Arc<std::sync::atomic::AtomicUsize>,
101 config: ParallelConfig,
103}
104
/// Adam bookkeeping for a single parameter tensor.
#[derive(Debug)]
pub struct ParameterState {
    /// First-moment (momentum) estimate, one entry per parameter element.
    pub momentum: Vec<f32>,
    /// Second-moment (variance) estimate, one entry per parameter element.
    pub variance: Vec<f32>,
    /// Number of updates applied to this parameter so far.
    pub step: usize,
    /// Time at which this state was created or last updated.
    pub last_update: std::time::Instant,
}

impl ParameterState {
    /// Builds zero-initialised state for a parameter with `size` elements.
    ///
    /// The step counter starts at zero and `last_update` is set to the
    /// current instant.
    fn new(size: usize) -> Self {
        let zeros = || vec![0.0_f32; size];
        ParameterState {
            momentum: zeros(),
            variance: zeros(),
            step: 0,
            last_update: std::time::Instant::now(),
        }
    }
}
124
125impl ParallelOptimizerState {
126 pub fn new(config: ParallelConfig) -> Self {
128 Self {
129 parameter_states: RwLock::new(HashMap::new()),
130 global_step: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
131 config,
132 }
133 }
134
135 pub fn get_or_create_state(&self, param_id: String, size: usize) -> Arc<Mutex<ParameterState>> {
137 {
139 let states = self
140 .parameter_states
141 .read()
142 .expect("parameter_states lock should not be poisoned");
143 if let Some(state) = states.get(¶m_id) {
144 return state.clone();
145 }
146 }
147
148 let mut states = self
150 .parameter_states
151 .write()
152 .expect("parameter_states lock should not be poisoned");
153 if let Some(state) = states.get(¶m_id) {
155 return state.clone();
156 }
157
158 let new_state = Arc::new(Mutex::new(ParameterState::new(size)));
159 states.insert(param_id, new_state.clone());
160 new_state
161 }
162
163 pub fn step(&self) {
165 self.global_step.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
166 }
167
168 pub fn get_step(&self) -> usize {
170 self.global_step.load(std::sync::atomic::Ordering::Relaxed)
171 }
172
173 pub fn memory_usage(&self) -> StateMemoryStats {
175 let states = self
176 .parameter_states
177 .read()
178 .expect("parameter_states lock should not be poisoned");
179 let mut total_momentum = 0;
180 let mut total_variance = 0;
181 let num_params = states.len();
182
183 for state_arc in states.values() {
184 if let Ok(state) = state_arc.try_lock() {
185 total_momentum += state.momentum.len();
186 total_variance += state.variance.len();
187 }
188 }
189
190 StateMemoryStats {
191 momentum_elements: total_momentum,
192 variance_elements: total_variance,
193 third_moment_elements: 0,
194 total_bytes: (total_momentum + total_variance) * std::mem::size_of::<f32>(),
195 num_parameters: num_params,
196 }
197 }
198
199 pub fn clear(&self) {
201 let mut states = self
202 .parameter_states
203 .write()
204 .expect("parameter_states lock should not be poisoned");
205 states.clear();
206 self.global_step.store(0, std::sync::atomic::Ordering::Relaxed);
207 }
208}
209
210#[derive(Debug)]
212pub struct ParallelAdam {
213 lr: f32,
215 betas: (f32, f32),
217 eps: f32,
219 weight_decay: f32,
221 state: ParallelOptimizerState,
223}
224
225impl ParallelAdam {
226 pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
228 Self::with_config(lr, betas, eps, weight_decay, ParallelConfig::default())
229 }
230
231 pub fn with_config(
233 lr: f32,
234 betas: (f32, f32),
235 eps: f32,
236 weight_decay: f32,
237 config: ParallelConfig,
238 ) -> Self {
239 Self {
240 lr,
241 betas,
242 eps,
243 weight_decay,
244 state: ParallelOptimizerState::new(config),
245 }
246 }
247
248 pub fn update_parallel(&self, updates: Vec<(String, &mut [f32], &[f32])>) -> Result<()> {
250 let _chunk_size = self.state.config.chunk_size;
251 let min_params = self.state.config.min_params_per_thread;
252
253 if updates.len() < min_params || !self.should_parallelize(&updates) {
254 return self.update_sequential(updates);
256 }
257
258 let results: Result<Vec<()>> = updates
260 .into_par_iter()
261 .with_min_len(1)
262 .map(|(param_id, param, grad)| self.update_single_parameter(param_id, param, grad))
263 .collect();
264
265 results.map(|_| ())
266 }
267
268 fn update_sequential(&self, updates: Vec<(String, &mut [f32], &[f32])>) -> Result<()> {
270 for (param_id, param, grad) in updates {
271 self.update_single_parameter(param_id, param, grad)?;
272 }
273 Ok(())
274 }
275
276 fn update_single_parameter(
278 &self,
279 param_id: String,
280 param: &mut [f32],
281 grad: &[f32],
282 ) -> Result<()> {
283 if param.len() != grad.len() {
284 return Err(TrustformersError::tensor_op_error(
285 "Parameter and gradient size mismatch",
286 "update_single_parameter",
287 ));
288 }
289
290 let size = param.len();
291 let state_arc = self.state.get_or_create_state(param_id, size);
292 let chunk_size = self.state.config.chunk_size;
293
294 let mut param_state = state_arc.lock().expect("Parallel optimizer state lock poisoned");
296 param_state.step += 1;
297 param_state.last_update = std::time::Instant::now();
298
299 let step = param_state.step;
300 let (bias_correction1, bias_correction2) =
301 BiasCorrection::compute_adam_corrections(self.betas.0, self.betas.1, step);
302
303 let should_parallelize = size >= chunk_size * 2 && self.state.config.num_threads > 1;
305 if should_parallelize {
306 let ParameterState {
308 ref mut momentum,
309 ref mut variance,
310 ..
311 } = *param_state;
312 self.update_parameter_parallel(
313 param,
314 grad,
315 momentum,
316 variance,
317 bias_correction1,
318 bias_correction2,
319 chunk_size,
320 );
321 } else {
322 let ParameterState {
324 ref mut momentum,
325 ref mut variance,
326 ..
327 } = *param_state;
328 self.update_parameter_sequential(
329 param,
330 grad,
331 momentum,
332 variance,
333 bias_correction1,
334 bias_correction2,
335 );
336 }
337
338 Ok(())
339 }
340
341 fn update_parameter_parallel(
343 &self,
344 param: &mut [f32],
345 grad: &[f32],
346 momentum: &mut [f32],
347 variance: &mut [f32],
348 bias_correction1: f32,
349 bias_correction2: f32,
350 chunk_size: usize,
351 ) {
352 param
354 .par_chunks_mut(chunk_size)
355 .zip(grad.par_chunks(chunk_size))
356 .zip(momentum.par_chunks_mut(chunk_size))
357 .zip(variance.par_chunks_mut(chunk_size))
358 .for_each(|(((p_chunk, g_chunk), m_chunk), v_chunk)| {
359 self.process_chunk(
360 p_chunk,
361 g_chunk,
362 m_chunk,
363 v_chunk,
364 bias_correction1,
365 bias_correction2,
366 );
367 });
368 }
369
370 fn update_parameter_sequential(
372 &self,
373 param: &mut [f32],
374 grad: &[f32],
375 momentum: &mut [f32],
376 variance: &mut [f32],
377 bias_correction1: f32,
378 bias_correction2: f32,
379 ) {
380 self.process_chunk(
381 param,
382 grad,
383 momentum,
384 variance,
385 bias_correction1,
386 bias_correction2,
387 );
388 }
389
390 #[inline]
392 fn process_chunk(
393 &self,
394 param_chunk: &mut [f32],
395 grad_chunk: &[f32],
396 momentum_chunk: &mut [f32],
397 variance_chunk: &mut [f32],
398 bias_correction1: f32,
399 bias_correction2: f32,
400 ) {
401 let len = param_chunk
403 .len()
404 .min(grad_chunk.len())
405 .min(momentum_chunk.len())
406 .min(variance_chunk.len());
407
408 for i in 0..len {
409 let grad_val = grad_chunk[i] + self.weight_decay * param_chunk[i];
410
411 ParameterUpdate::update_ema(&mut momentum_chunk[i], grad_val, self.betas.0);
413 ParameterUpdate::update_ema(&mut variance_chunk[i], grad_val * grad_val, self.betas.1);
414
415 let m_hat = momentum_chunk[i] / bias_correction1;
417 let v_hat = variance_chunk[i] / bias_correction2;
418
419 ParameterUpdate::adam_update(&mut param_chunk[i], self.lr, m_hat, v_hat, self.eps);
420 }
421 }
422
423 fn should_parallelize(&self, updates: &[(String, &mut [f32], &[f32])]) -> bool {
425 let total_elements: usize = updates.iter().map(|(_, param, _)| param.len()).sum();
426 let num_threads = self.state.config.effective_num_threads();
427
428 total_elements >= self.state.config.min_params_per_thread * num_threads
429 }
430
431 pub fn parallel_stats(&self) -> ParallelStats {
433 let memory_stats = self.state.memory_usage();
434 let num_threads = self.state.config.effective_num_threads();
435
436 ParallelStats {
437 num_threads,
438 memory_stats,
439 config: self.state.config.clone(),
440 current_step: self.state.get_step(),
441 }
442 }
443
444 pub fn configure_thread_pool(&self) -> Result<()> {
446 let num_threads = self.state.config.effective_num_threads();
447
448 ThreadPoolBuilder::new().num_threads(num_threads).build_global().map_err(|e| {
449 TrustformersError::tensor_op_error(
450 &format!("Failed to configure thread pool: {}", e),
451 "configure_thread_pool",
452 )
453 })?;
454
455 Ok(())
456 }
457}
458
459impl Optimizer for ParallelAdam {
460 fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
461 match (parameter, grad) {
462 (Tensor::F32(param), Tensor::F32(grad_arr)) => {
463 let param_id = format!("{:p}", param.as_ptr());
464 self.update_single_parameter(
465 param_id,
466 param.as_slice_mut().expect("array must have contiguous layout"),
467 grad_arr.as_slice().expect("array must have contiguous layout"),
468 )
469 },
470 _ => Err(TrustformersError::tensor_op_error(
471 "Unsupported tensor types for ParallelAdam",
472 "update",
473 )),
474 }
475 }
476
477 fn zero_grad(&mut self) {
478 }
480
481 fn step(&mut self) {
482 self.state.step();
483 }
484
485 fn get_lr(&self) -> f32 {
486 self.lr
487 }
488
489 fn set_lr(&mut self, lr: f32) {
490 self.lr = lr;
491 }
492}
493
494#[derive(Debug, Clone)]
496pub struct ParallelStats {
497 pub num_threads: usize,
499 pub memory_stats: StateMemoryStats,
501 pub config: ParallelConfig,
503 pub current_step: usize,
505}
506
507impl ParallelStats {
508 pub fn theoretical_speedup(&self, _sequential_time_ms: f64) -> f64 {
510 let parallel_fraction = 0.95; let num_threads = self.num_threads as f64;
513
514 1.0 / ((1.0 - parallel_fraction) + (parallel_fraction / num_threads))
515 }
516
517 pub fn optimization_suggestions(&self) -> Vec<String> {
519 let mut suggestions = Vec::new();
520
521 if self.num_threads == 1 {
522 suggestions.push(
523 "Consider increasing number of threads for better parallelization".to_string(),
524 );
525 }
526
527 if self.num_threads > num_cpus::get() {
528 suggestions.push("Number of threads exceeds CPU cores; consider reducing".to_string());
529 }
530
531 if self.config.chunk_size < 256 {
532 suggestions
533 .push("Small chunk size may cause overhead; consider increasing".to_string());
534 }
535
536 if self.config.chunk_size > 8192 {
537 suggestions.push("Large chunk size may reduce parallelization efficiency".to_string());
538 }
539
540 if !self.config.enable_work_stealing {
541 suggestions.push("Enable work stealing for better load balancing".to_string());
542 }
543
544 if suggestions.is_empty() {
545 suggestions.push("Parallel configuration appears optimal".to_string());
546 }
547
548 suggestions
549 }
550}
551
552pub trait BatchUpdate {
554 fn update_batch(&mut self, batch: Vec<(&mut Tensor, &Tensor)>) -> Result<()>;
556}
557
558impl BatchUpdate for ParallelAdam {
559 fn update_batch(&mut self, batch: Vec<(&mut Tensor, &Tensor)>) -> Result<()> {
560 let mut updates = Vec::new();
561
562 for (param, grad) in batch {
563 match (param, grad) {
564 (Tensor::F32(p), Tensor::F32(g)) => {
565 let param_id = format!("{:p}", p.as_ptr());
566 updates.push((
567 param_id,
568 p.as_slice_mut().expect("array must have contiguous layout"),
569 g.as_slice().expect("array must have contiguous layout"),
570 ));
571 },
572 _ => {
573 return Err(TrustformersError::tensor_op_error(
574 "Unsupported tensor types",
575 "update_batch",
576 ))
577 },
578 }
579 }
580
581 self.update_parallel(updates)
582 }
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    /// Default and preset configurations resolve to sensible thread counts.
    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::default();
        // 0 means "auto-detect".
        assert_eq!(config.num_threads, 0);
        assert!(config.enable_work_stealing);

        let cpu_config = ParallelConfig::cpu_optimized();
        assert_eq!(cpu_config.num_threads, num_cpus::get());

        let effective_threads = config.effective_num_threads();
        assert!(effective_threads > 0);
        assert_eq!(effective_threads, num_cpus::get());
    }

    /// The step counter advances and new parameter state has the right size.
    #[test]
    fn test_parallel_optimizer_state() {
        let config = ParallelConfig::default();
        let state = ParallelOptimizerState::new(config);

        assert_eq!(state.get_step(), 0);
        state.step();
        assert_eq!(state.get_step(), 1);

        let param_state = state.get_or_create_state("test_param".to_string(), 100);
        let locked_state = param_state.lock().expect("Parallel optimizer state lock poisoned");
        assert_eq!(locked_state.momentum.len(), 100);
        assert_eq!(locked_state.variance.len(), 100);
    }

    /// A freshly built optimizer exposes its hyperparameters and zero steps.
    #[test]
    fn test_parallel_adam() {
        let optimizer = ParallelAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.betas, (0.9, 0.999));

        let stats = optimizer.parallel_stats();
        assert!(stats.num_threads > 0);
        assert_eq!(stats.current_step, 0);
    }

    /// Only batches with enough total elements qualify for parallel updates.
    #[test]
    fn test_should_parallelize() {
        let config = ParallelConfig {
            min_params_per_thread: 1000,
            num_threads: 4,
            ..Default::default()
        };
        let optimizer = ParallelAdam::with_config(1e-3, (0.9, 0.999), 1e-8, 0.01, config);

        // 100 elements is below the 4 * 1000 threshold.
        let mut small_params = [0.0; 100];
        let small_grads = [0.0; 100];
        let small_updates = vec![(
            "param1".to_string(),
            &mut small_params[..],
            &small_grads[..],
        )];
        assert!(!optimizer.should_parallelize(&small_updates));

        // 5000 elements is above the 4 * 1000 threshold.
        let mut large_params = [0.0; 5000];
        let large_grads = [0.0; 5000];
        let large_updates = vec![(
            "param1".to_string(),
            &mut large_params[..],
            &large_grads[..],
        )];
        assert!(optimizer.should_parallelize(&large_updates));
    }

    /// The speedup estimate stays within bounds and suggestions are produced.
    #[test]
    fn test_parallel_stats() {
        let optimizer = ParallelAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        let stats = optimizer.parallel_stats();

        let speedup = stats.theoretical_speedup(1000.0);
        // On a single-CPU host the estimate is exactly 1.0, so the lower
        // bound has to be inclusive.
        assert!(speedup >= 1.0);
        assert!(speedup <= stats.num_threads as f64);

        let suggestions = stats.optimization_suggestions();
        assert!(!suggestions.is_empty());
    }

    /// Memory statistics add up the element counts of all parameter states.
    #[test]
    fn test_memory_usage() {
        let config = ParallelConfig::default();
        let state = ParallelOptimizerState::new(config);

        state.get_or_create_state("param1".to_string(), 1000);
        state.get_or_create_state("param2".to_string(), 2000);

        let memory_stats = state.memory_usage();
        assert_eq!(memory_stats.num_parameters, 2);
        assert_eq!(memory_stats.momentum_elements, 3000);
        assert_eq!(memory_stats.variance_elements, 3000);
    }
}