Skip to main content

trustformers_optim/
parallel.rs

1//! Parallel optimization algorithms for multi-threaded training.
2//!
3//! This module provides thread-safe optimizers that can leverage multiple CPU cores
4//! for parallel parameter updates, improving performance on multi-core systems.
5//!
6//! # Key Features
7//!
8//! - **Thread-Safe State Management**: Lock-free and fine-grained locking strategies
9//! - **Parallel Parameter Updates**: Distribute parameter updates across threads
10//! - **Work Stealing**: Dynamic load balancing for uneven parameter distributions
11//! - **NUMA Awareness**: Optimize for Non-Uniform Memory Access architectures
12//! - **Scalability**: Efficient scaling from 2 to 64+ cores
13
14use crate::common::{BiasCorrection, ParameterUpdate, StateMemoryStats};
15use scirs2_core::parallel_ops::*; // SciRS2 Integration Policy - replaces rayon
16use std::collections::HashMap;
17use std::sync::{Arc, Mutex, RwLock};
18use trustformers_core::errors::{Result, TrustformersError};
19use trustformers_core::tensor::Tensor;
20use trustformers_core::traits::Optimizer;
21
22/// Configuration for parallel optimization.
23#[derive(Debug, Clone)]
24pub struct ParallelConfig {
25    /// Number of worker threads (0 = auto-detect)
26    pub num_threads: usize,
27    /// Minimum parameters per thread to justify parallelization
28    pub min_params_per_thread: usize,
29    /// Enable work stealing for load balancing
30    pub enable_work_stealing: bool,
31    /// Enable NUMA-aware thread pinning
32    pub numa_aware: bool,
33    /// Chunk size for parameter processing
34    pub chunk_size: usize,
35    /// Enable lock-free optimizations where possible
36    pub lock_free: bool,
37}
38
39impl Default for ParallelConfig {
40    fn default() -> Self {
41        Self {
42            num_threads: 0, // Auto-detect
43            min_params_per_thread: 1000,
44            enable_work_stealing: true,
45            numa_aware: false,
46            chunk_size: 1024,
47            lock_free: true,
48        }
49    }
50}
51
52impl ParallelConfig {
53    /// Creates configuration optimized for CPU-bound workloads.
54    pub fn cpu_optimized() -> Self {
55        Self {
56            num_threads: num_cpus::get(),
57            chunk_size: 512,
58            enable_work_stealing: true,
59            ..Default::default()
60        }
61    }
62
63    /// Creates configuration for large model training.
64    pub fn large_model() -> Self {
65        Self {
66            num_threads: num_cpus::get(),
67            min_params_per_thread: 10000,
68            chunk_size: 4096,
69            numa_aware: true,
70            ..Default::default()
71        }
72    }
73
74    /// Creates configuration for memory-bound workloads.
75    pub fn memory_bound() -> Self {
76        Self {
77            num_threads: (num_cpus::get() / 2).max(1),
78            chunk_size: 2048,
79            numa_aware: true,
80            ..Default::default()
81        }
82    }
83
84    /// Gets the effective number of threads.
85    pub fn effective_num_threads(&self) -> usize {
86        if self.num_threads == 0 {
87            num_cpus::get()
88        } else {
89            self.num_threads
90        }
91    }
92}
93
94/// Thread-safe optimizer state with fine-grained locking.
95#[derive(Debug)]
96pub struct ParallelOptimizerState {
97    /// Per-parameter state with individual locks
98    parameter_states: RwLock<HashMap<String, Arc<Mutex<ParameterState>>>>,
99    /// Global step counter
100    global_step: Arc<std::sync::atomic::AtomicUsize>,
101    /// Parallel configuration
102    config: ParallelConfig,
103}
104
105/// Individual parameter state with momentum and variance.
106#[derive(Debug)]
107pub struct ParameterState {
108    pub momentum: Vec<f32>,
109    pub variance: Vec<f32>,
110    pub step: usize,
111    pub last_update: std::time::Instant,
112}
113
114impl ParameterState {
115    fn new(size: usize) -> Self {
116        Self {
117            momentum: vec![0.0; size],
118            variance: vec![0.0; size],
119            step: 0,
120            last_update: std::time::Instant::now(),
121        }
122    }
123}
124
125impl ParallelOptimizerState {
126    /// Creates a new parallel optimizer state.
127    pub fn new(config: ParallelConfig) -> Self {
128        Self {
129            parameter_states: RwLock::new(HashMap::new()),
130            global_step: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
131            config,
132        }
133    }
134
135    /// Gets or creates parameter state.
136    pub fn get_or_create_state(&self, param_id: String, size: usize) -> Arc<Mutex<ParameterState>> {
137        // Try read-only access first
138        {
139            let states = self
140                .parameter_states
141                .read()
142                .expect("parameter_states lock should not be poisoned");
143            if let Some(state) = states.get(&param_id) {
144                return state.clone();
145            }
146        }
147
148        // Need to create new state - upgrade to write lock
149        let mut states = self
150            .parameter_states
151            .write()
152            .expect("parameter_states lock should not be poisoned");
153        // Double-check pattern in case another thread created it
154        if let Some(state) = states.get(&param_id) {
155            return state.clone();
156        }
157
158        let new_state = Arc::new(Mutex::new(ParameterState::new(size)));
159        states.insert(param_id, new_state.clone());
160        new_state
161    }
162
163    /// Increments global step counter atomically.
164    pub fn step(&self) {
165        self.global_step.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
166    }
167
168    /// Gets current global step.
169    pub fn get_step(&self) -> usize {
170        self.global_step.load(std::sync::atomic::Ordering::Relaxed)
171    }
172
173    /// Gets memory usage statistics.
174    pub fn memory_usage(&self) -> StateMemoryStats {
175        let states = self
176            .parameter_states
177            .read()
178            .expect("parameter_states lock should not be poisoned");
179        let mut total_momentum = 0;
180        let mut total_variance = 0;
181        let num_params = states.len();
182
183        for state_arc in states.values() {
184            if let Ok(state) = state_arc.try_lock() {
185                total_momentum += state.momentum.len();
186                total_variance += state.variance.len();
187            }
188        }
189
190        StateMemoryStats {
191            momentum_elements: total_momentum,
192            variance_elements: total_variance,
193            third_moment_elements: 0,
194            total_bytes: (total_momentum + total_variance) * std::mem::size_of::<f32>(),
195            num_parameters: num_params,
196        }
197    }
198
199    /// Clears all parameter states.
200    pub fn clear(&self) {
201        let mut states = self
202            .parameter_states
203            .write()
204            .expect("parameter_states lock should not be poisoned");
205        states.clear();
206        self.global_step.store(0, std::sync::atomic::Ordering::Relaxed);
207    }
208}
209
210/// Parallel Adam optimizer with multi-threaded parameter updates.
211#[derive(Debug)]
212pub struct ParallelAdam {
213    /// Learning rate
214    lr: f32,
215    /// Beta coefficients
216    betas: (f32, f32),
217    /// Epsilon for numerical stability
218    eps: f32,
219    /// Weight decay coefficient
220    weight_decay: f32,
221    /// Parallel optimizer state
222    state: ParallelOptimizerState,
223}
224
225impl ParallelAdam {
226    /// Creates a new parallel Adam optimizer.
227    pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
228        Self::with_config(lr, betas, eps, weight_decay, ParallelConfig::default())
229    }
230
231    /// Creates a parallel Adam optimizer with custom configuration.
232    pub fn with_config(
233        lr: f32,
234        betas: (f32, f32),
235        eps: f32,
236        weight_decay: f32,
237        config: ParallelConfig,
238    ) -> Self {
239        Self {
240            lr,
241            betas,
242            eps,
243            weight_decay,
244            state: ParallelOptimizerState::new(config),
245        }
246    }
247
248    /// Updates multiple parameters in parallel.
249    pub fn update_parallel(&self, updates: Vec<(String, &mut [f32], &[f32])>) -> Result<()> {
250        let _chunk_size = self.state.config.chunk_size;
251        let min_params = self.state.config.min_params_per_thread;
252
253        if updates.len() < min_params || !self.should_parallelize(&updates) {
254            // Use sequential processing for small workloads
255            return self.update_sequential(updates);
256        }
257
258        // Parallel processing using rayon
259        let results: Result<Vec<()>> = updates
260            .into_par_iter()
261            .with_min_len(1)
262            .map(|(param_id, param, grad)| self.update_single_parameter(param_id, param, grad))
263            .collect();
264
265        results.map(|_| ())
266    }
267
268    /// Updates parameters sequentially.
269    fn update_sequential(&self, updates: Vec<(String, &mut [f32], &[f32])>) -> Result<()> {
270        for (param_id, param, grad) in updates {
271            self.update_single_parameter(param_id, param, grad)?;
272        }
273        Ok(())
274    }
275
276    /// Updates a single parameter with parallel chunk processing.
277    fn update_single_parameter(
278        &self,
279        param_id: String,
280        param: &mut [f32],
281        grad: &[f32],
282    ) -> Result<()> {
283        if param.len() != grad.len() {
284            return Err(TrustformersError::tensor_op_error(
285                "Parameter and gradient size mismatch",
286                "update_single_parameter",
287            ));
288        }
289
290        let size = param.len();
291        let state_arc = self.state.get_or_create_state(param_id, size);
292        let chunk_size = self.state.config.chunk_size;
293
294        // Lock the parameter state
295        let mut param_state = state_arc.lock().expect("Parallel optimizer state lock poisoned");
296        param_state.step += 1;
297        param_state.last_update = std::time::Instant::now();
298
299        let step = param_state.step;
300        let (bias_correction1, bias_correction2) =
301            BiasCorrection::compute_adam_corrections(self.betas.0, self.betas.1, step);
302
303        // Determine if we should parallelize this parameter
304        let should_parallelize = size >= chunk_size * 2 && self.state.config.num_threads > 1;
305        if should_parallelize {
306            // Parallel chunk processing
307            let ParameterState {
308                ref mut momentum,
309                ref mut variance,
310                ..
311            } = *param_state;
312            self.update_parameter_parallel(
313                param,
314                grad,
315                momentum,
316                variance,
317                bias_correction1,
318                bias_correction2,
319                chunk_size,
320            );
321        } else {
322            // Sequential processing for small parameters
323            let ParameterState {
324                ref mut momentum,
325                ref mut variance,
326                ..
327            } = *param_state;
328            self.update_parameter_sequential(
329                param,
330                grad,
331                momentum,
332                variance,
333                bias_correction1,
334                bias_correction2,
335            );
336        }
337
338        Ok(())
339    }
340
341    /// Updates parameter using parallel chunk processing.
342    fn update_parameter_parallel(
343        &self,
344        param: &mut [f32],
345        grad: &[f32],
346        momentum: &mut [f32],
347        variance: &mut [f32],
348        bias_correction1: f32,
349        bias_correction2: f32,
350        chunk_size: usize,
351    ) {
352        // Use parallel iterators for chunk-based processing
353        param
354            .par_chunks_mut(chunk_size)
355            .zip(grad.par_chunks(chunk_size))
356            .zip(momentum.par_chunks_mut(chunk_size))
357            .zip(variance.par_chunks_mut(chunk_size))
358            .for_each(|(((p_chunk, g_chunk), m_chunk), v_chunk)| {
359                self.process_chunk(
360                    p_chunk,
361                    g_chunk,
362                    m_chunk,
363                    v_chunk,
364                    bias_correction1,
365                    bias_correction2,
366                );
367            });
368    }
369
370    /// Updates parameter sequentially.
371    fn update_parameter_sequential(
372        &self,
373        param: &mut [f32],
374        grad: &[f32],
375        momentum: &mut [f32],
376        variance: &mut [f32],
377        bias_correction1: f32,
378        bias_correction2: f32,
379    ) {
380        self.process_chunk(
381            param,
382            grad,
383            momentum,
384            variance,
385            bias_correction1,
386            bias_correction2,
387        );
388    }
389
390    /// Processes a chunk of parameters.
391    #[inline]
392    fn process_chunk(
393        &self,
394        param_chunk: &mut [f32],
395        grad_chunk: &[f32],
396        momentum_chunk: &mut [f32],
397        variance_chunk: &mut [f32],
398        bias_correction1: f32,
399        bias_correction2: f32,
400    ) {
401        // Use the minimum length to avoid index out of bounds
402        let len = param_chunk
403            .len()
404            .min(grad_chunk.len())
405            .min(momentum_chunk.len())
406            .min(variance_chunk.len());
407
408        for i in 0..len {
409            let grad_val = grad_chunk[i] + self.weight_decay * param_chunk[i];
410
411            // Update momentum and variance
412            ParameterUpdate::update_ema(&mut momentum_chunk[i], grad_val, self.betas.0);
413            ParameterUpdate::update_ema(&mut variance_chunk[i], grad_val * grad_val, self.betas.1);
414
415            // Apply bias-corrected update
416            let m_hat = momentum_chunk[i] / bias_correction1;
417            let v_hat = variance_chunk[i] / bias_correction2;
418
419            ParameterUpdate::adam_update(&mut param_chunk[i], self.lr, m_hat, v_hat, self.eps);
420        }
421    }
422
423    /// Determines if parallelization should be used based on workload.
424    fn should_parallelize(&self, updates: &[(String, &mut [f32], &[f32])]) -> bool {
425        let total_elements: usize = updates.iter().map(|(_, param, _)| param.len()).sum();
426        let num_threads = self.state.config.effective_num_threads();
427
428        total_elements >= self.state.config.min_params_per_thread * num_threads
429    }
430
431    /// Gets parallel performance statistics.
432    pub fn parallel_stats(&self) -> ParallelStats {
433        let memory_stats = self.state.memory_usage();
434        let num_threads = self.state.config.effective_num_threads();
435
436        ParallelStats {
437            num_threads,
438            memory_stats,
439            config: self.state.config.clone(),
440            current_step: self.state.get_step(),
441        }
442    }
443
444    /// Configures thread pool for optimal performance.
445    pub fn configure_thread_pool(&self) -> Result<()> {
446        let num_threads = self.state.config.effective_num_threads();
447
448        ThreadPoolBuilder::new().num_threads(num_threads).build_global().map_err(|e| {
449            TrustformersError::tensor_op_error(
450                &format!("Failed to configure thread pool: {}", e),
451                "configure_thread_pool",
452            )
453        })?;
454
455        Ok(())
456    }
457}
458
459impl Optimizer for ParallelAdam {
460    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
461        match (parameter, grad) {
462            (Tensor::F32(param), Tensor::F32(grad_arr)) => {
463                let param_id = format!("{:p}", param.as_ptr());
464                self.update_single_parameter(
465                    param_id,
466                    param.as_slice_mut().expect("array must have contiguous layout"),
467                    grad_arr.as_slice().expect("array must have contiguous layout"),
468                )
469            },
470            _ => Err(TrustformersError::tensor_op_error(
471                "Unsupported tensor types for ParallelAdam",
472                "update",
473            )),
474        }
475    }
476
477    fn zero_grad(&mut self) {
478        // No explicit gradient storage
479    }
480
481    fn step(&mut self) {
482        self.state.step();
483    }
484
485    fn get_lr(&self) -> f32 {
486        self.lr
487    }
488
489    fn set_lr(&mut self, lr: f32) {
490        self.lr = lr;
491    }
492}
493
494/// Performance statistics for parallel optimization.
495#[derive(Debug, Clone)]
496pub struct ParallelStats {
497    /// Number of worker threads
498    pub num_threads: usize,
499    /// Memory usage statistics
500    pub memory_stats: StateMemoryStats,
501    /// Parallel configuration
502    pub config: ParallelConfig,
503    /// Current optimization step
504    pub current_step: usize,
505}
506
507impl ParallelStats {
508    /// Calculates theoretical speedup based on workload.
509    pub fn theoretical_speedup(&self, _sequential_time_ms: f64) -> f64 {
510        // Simple Amdahl's law approximation
511        let parallel_fraction = 0.95; // Assume 95% of work can be parallelized
512        let num_threads = self.num_threads as f64;
513
514        1.0 / ((1.0 - parallel_fraction) + (parallel_fraction / num_threads))
515    }
516
517    /// Suggests optimization improvements.
518    pub fn optimization_suggestions(&self) -> Vec<String> {
519        let mut suggestions = Vec::new();
520
521        if self.num_threads == 1 {
522            suggestions.push(
523                "Consider increasing number of threads for better parallelization".to_string(),
524            );
525        }
526
527        if self.num_threads > num_cpus::get() {
528            suggestions.push("Number of threads exceeds CPU cores; consider reducing".to_string());
529        }
530
531        if self.config.chunk_size < 256 {
532            suggestions
533                .push("Small chunk size may cause overhead; consider increasing".to_string());
534        }
535
536        if self.config.chunk_size > 8192 {
537            suggestions.push("Large chunk size may reduce parallelization efficiency".to_string());
538        }
539
540        if !self.config.enable_work_stealing {
541            suggestions.push("Enable work stealing for better load balancing".to_string());
542        }
543
544        if suggestions.is_empty() {
545            suggestions.push("Parallel configuration appears optimal".to_string());
546        }
547
548        suggestions
549    }
550}
551
552/// Batch parameter update interface for better parallelization.
553pub trait BatchUpdate {
554    /// Updates multiple parameters in a single batch operation.
555    fn update_batch(&mut self, batch: Vec<(&mut Tensor, &Tensor)>) -> Result<()>;
556}
557
558impl BatchUpdate for ParallelAdam {
559    fn update_batch(&mut self, batch: Vec<(&mut Tensor, &Tensor)>) -> Result<()> {
560        let mut updates = Vec::new();
561
562        for (param, grad) in batch {
563            match (param, grad) {
564                (Tensor::F32(p), Tensor::F32(g)) => {
565                    let param_id = format!("{:p}", p.as_ptr());
566                    updates.push((
567                        param_id,
568                        p.as_slice_mut().expect("array must have contiguous layout"),
569                        g.as_slice().expect("array must have contiguous layout"),
570                    ));
571                },
572                _ => {
573                    return Err(TrustformersError::tensor_op_error(
574                        "Unsupported tensor types",
575                        "update_batch",
576                    ))
577                },
578            }
579        }
580
581        self.update_parallel(updates)
582    }
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    /// Default and CPU-optimized presets, plus thread-count auto-detection.
    #[test]
    fn test_parallel_config() {
        let cores = num_cpus::get();

        let defaults = ParallelConfig::default();
        assert_eq!(defaults.num_threads, 0); // 0 requests auto-detection
        assert!(defaults.enable_work_stealing);

        assert_eq!(ParallelConfig::cpu_optimized().num_threads, cores);

        let resolved = defaults.effective_num_threads();
        assert!(resolved > 0);
        assert_eq!(resolved, cores);
    }

    /// Global step counting and lazy creation of per-parameter state.
    #[test]
    fn test_parallel_optimizer_state() {
        let store = ParallelOptimizerState::new(ParallelConfig::default());

        assert_eq!(store.get_step(), 0);
        store.step();
        assert_eq!(store.get_step(), 1);

        let handle = store.get_or_create_state("test_param".to_string(), 100);
        let guard = handle.lock().expect("Parallel optimizer state lock poisoned");
        assert_eq!(guard.momentum.len(), 100);
        assert_eq!(guard.variance.len(), 100);
    }

    /// Hyperparameters are stored and stats start at step zero.
    #[test]
    fn test_parallel_adam() {
        let adam = ParallelAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        assert_eq!(adam.get_lr(), 1e-3);
        assert_eq!(adam.betas, (0.9, 0.999));

        let snapshot = adam.parallel_stats();
        assert!(snapshot.num_threads > 0);
        assert_eq!(snapshot.current_step, 0);
    }

    /// The workload threshold is `min_params_per_thread * num_threads` elements.
    #[test]
    fn test_should_parallelize() {
        let adam = ParallelAdam::with_config(
            1e-3,
            (0.9, 0.999),
            1e-8,
            0.01,
            ParallelConfig {
                num_threads: 4,
                min_params_per_thread: 1000,
                ..Default::default()
            },
        );

        // 100 elements < 4 * 1000: stay sequential.
        let mut tiny_param = [0.0_f32; 100];
        let tiny_grad = [0.0_f32; 100];
        let tiny = vec![("param1".to_string(), &mut tiny_param[..], &tiny_grad[..])];
        assert!(!adam.should_parallelize(&tiny));

        // 5000 elements >= 4 * 1000: go parallel.
        let mut big_param = [0.0_f32; 5000];
        let big_grad = [0.0_f32; 5000];
        let big = vec![("param1".to_string(), &mut big_param[..], &big_grad[..])];
        assert!(adam.should_parallelize(&big));
    }

    /// Amdahl estimate stays within (1, num_threads]; suggestions are never empty.
    #[test]
    fn test_parallel_stats() {
        let snapshot = ParallelAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01).parallel_stats();

        let estimate = snapshot.theoretical_speedup(1000.0);
        assert!(estimate > 1.0);
        assert!(estimate <= snapshot.num_threads as f64);

        assert!(!snapshot.optimization_suggestions().is_empty());
    }

    /// Memory accounting sums moment buffers across all tracked parameters.
    #[test]
    fn test_memory_usage() {
        let store = ParallelOptimizerState::new(ParallelConfig::default());

        for (name, size) in [("param1", 1000), ("param2", 2000)] {
            store.get_or_create_state(name.to_string(), size);
        }

        let usage = store.memory_usage();
        assert_eq!(usage.num_parameters, 2);
        assert_eq!(usage.momentum_elements, 3000);
        assert_eq!(usage.variance_elements, 3000);
    }
}