// trustformers_optim/parallel.rs
//! Parallel optimization algorithms for multi-threaded training.
//!
//! This module provides thread-safe optimizers that can leverage multiple CPU cores
//! for parallel parameter updates, improving performance on multi-core systems.
//!
//! # Key Features
//!
//! - **Thread-Safe State Management**: Lock-free and fine-grained locking strategies
//! - **Parallel Parameter Updates**: Distribute parameter updates across threads
//! - **Work Stealing**: Dynamic load balancing for uneven parameter distributions
//! - **NUMA Awareness**: Optimize for Non-Uniform Memory Access architectures
//! - **Scalability**: Efficient scaling from 2 to 64+ cores
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};

use scirs2_core::parallel_ops::*; // SciRS2 Integration Policy - replaces rayon

use crate::common::{BiasCorrection, ParameterUpdate, StateMemoryStats};
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;
21
/// Configuration for parallel optimization.
///
/// Controls how optimizers in this module split work across threads. The
/// preset constructors (`cpu_optimized`, `large_model`, `memory_bound`)
/// provide common starting points.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Number of worker threads (0 = auto-detect; see `effective_num_threads`).
    pub num_threads: usize,
    /// Minimum parameter *elements* per thread to justify parallelization.
    pub min_params_per_thread: usize,
    /// Enable work stealing for load balancing.
    pub enable_work_stealing: bool,
    /// Enable NUMA-aware thread pinning.
    /// NOTE(review): no code in this module acts on this flag yet — confirm
    /// it is consumed elsewhere.
    pub numa_aware: bool,
    /// Chunk size (in elements) used when splitting one parameter's update.
    pub chunk_size: usize,
    /// Enable lock-free optimizations where possible.
    /// NOTE(review): not read anywhere in this module — confirm intended use.
    pub lock_free: bool,
}
38
39impl Default for ParallelConfig {
40    fn default() -> Self {
41        Self {
42            num_threads: 0, // Auto-detect
43            min_params_per_thread: 1000,
44            enable_work_stealing: true,
45            numa_aware: false,
46            chunk_size: 1024,
47            lock_free: true,
48        }
49    }
50}
51
52impl ParallelConfig {
53    /// Creates configuration optimized for CPU-bound workloads.
54    pub fn cpu_optimized() -> Self {
55        Self {
56            num_threads: num_cpus::get(),
57            chunk_size: 512,
58            enable_work_stealing: true,
59            ..Default::default()
60        }
61    }
62
63    /// Creates configuration for large model training.
64    pub fn large_model() -> Self {
65        Self {
66            num_threads: num_cpus::get(),
67            min_params_per_thread: 10000,
68            chunk_size: 4096,
69            numa_aware: true,
70            ..Default::default()
71        }
72    }
73
74    /// Creates configuration for memory-bound workloads.
75    pub fn memory_bound() -> Self {
76        Self {
77            num_threads: (num_cpus::get() / 2).max(1),
78            chunk_size: 2048,
79            numa_aware: true,
80            ..Default::default()
81        }
82    }
83
84    /// Gets the effective number of threads.
85    pub fn effective_num_threads(&self) -> usize {
86        if self.num_threads == 0 {
87            num_cpus::get()
88        } else {
89            self.num_threads
90        }
91    }
92}
93
/// Thread-safe optimizer state with fine-grained locking.
///
/// The outer `RwLock` guards only the *map* of parameter states; each
/// parameter's buffers sit behind their own `Mutex`, so updates to different
/// parameters never contend on the same lock.
#[derive(Debug)]
pub struct ParallelOptimizerState {
    /// Per-parameter state, keyed by parameter id, each behind its own lock.
    parameter_states: RwLock<HashMap<String, Arc<Mutex<ParameterState>>>>,
    /// Global step counter, advanced atomically by `step()`.
    global_step: Arc<std::sync::atomic::AtomicUsize>,
    /// Parallel configuration driving chunking/thread decisions.
    config: ParallelConfig,
}
104
/// Individual parameter state with momentum and variance.
#[derive(Debug)]
pub struct ParameterState {
    /// First-moment (EMA of gradients) buffer, one entry per element.
    pub momentum: Vec<f32>,
    /// Second-moment (EMA of squared gradients) buffer, one entry per element.
    pub variance: Vec<f32>,
    /// Per-parameter step count, used for Adam bias correction.
    pub step: usize,
    /// Wall-clock time of the most recent update (diagnostics only).
    pub last_update: std::time::Instant,
}
113
114impl ParameterState {
115    fn new(size: usize) -> Self {
116        Self {
117            momentum: vec![0.0; size],
118            variance: vec![0.0; size],
119            step: 0,
120            last_update: std::time::Instant::now(),
121        }
122    }
123}
124
125impl ParallelOptimizerState {
126    /// Creates a new parallel optimizer state.
127    pub fn new(config: ParallelConfig) -> Self {
128        Self {
129            parameter_states: RwLock::new(HashMap::new()),
130            global_step: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
131            config,
132        }
133    }
134
135    /// Gets or creates parameter state.
136    pub fn get_or_create_state(&self, param_id: String, size: usize) -> Arc<Mutex<ParameterState>> {
137        // Try read-only access first
138        {
139            let states = self.parameter_states.read().unwrap();
140            if let Some(state) = states.get(&param_id) {
141                return state.clone();
142            }
143        }
144
145        // Need to create new state - upgrade to write lock
146        let mut states = self.parameter_states.write().unwrap();
147        // Double-check pattern in case another thread created it
148        if let Some(state) = states.get(&param_id) {
149            return state.clone();
150        }
151
152        let new_state = Arc::new(Mutex::new(ParameterState::new(size)));
153        states.insert(param_id, new_state.clone());
154        new_state
155    }
156
157    /// Increments global step counter atomically.
158    pub fn step(&self) {
159        self.global_step.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
160    }
161
162    /// Gets current global step.
163    pub fn get_step(&self) -> usize {
164        self.global_step.load(std::sync::atomic::Ordering::Relaxed)
165    }
166
167    /// Gets memory usage statistics.
168    pub fn memory_usage(&self) -> StateMemoryStats {
169        let states = self.parameter_states.read().unwrap();
170        let mut total_momentum = 0;
171        let mut total_variance = 0;
172        let num_params = states.len();
173
174        for state_arc in states.values() {
175            if let Ok(state) = state_arc.try_lock() {
176                total_momentum += state.momentum.len();
177                total_variance += state.variance.len();
178            }
179        }
180
181        StateMemoryStats {
182            momentum_elements: total_momentum,
183            variance_elements: total_variance,
184            third_moment_elements: 0,
185            total_bytes: (total_momentum + total_variance) * std::mem::size_of::<f32>(),
186            num_parameters: num_params,
187        }
188    }
189
190    /// Clears all parameter states.
191    pub fn clear(&self) {
192        let mut states = self.parameter_states.write().unwrap();
193        states.clear();
194        self.global_step.store(0, std::sync::atomic::Ordering::Relaxed);
195    }
196}
197
/// Parallel Adam optimizer with multi-threaded parameter updates.
///
/// Implements Adam with weight decay folded into the gradient (L2-style, not
/// decoupled AdamW — see `process_chunk`), distributing work both across
/// parameters and across chunks within a single large parameter.
#[derive(Debug)]
pub struct ParallelAdam {
    /// Learning rate
    lr: f32,
    /// Beta coefficients (beta1 for momentum, beta2 for variance)
    betas: (f32, f32),
    /// Epsilon for numerical stability
    eps: f32,
    /// Weight decay coefficient
    weight_decay: f32,
    /// Parallel optimizer state (per-parameter moments + global step)
    state: ParallelOptimizerState,
}
212
213impl ParallelAdam {
214    /// Creates a new parallel Adam optimizer.
215    pub fn new(lr: f32, betas: (f32, f32), eps: f32, weight_decay: f32) -> Self {
216        Self::with_config(lr, betas, eps, weight_decay, ParallelConfig::default())
217    }
218
219    /// Creates a parallel Adam optimizer with custom configuration.
220    pub fn with_config(
221        lr: f32,
222        betas: (f32, f32),
223        eps: f32,
224        weight_decay: f32,
225        config: ParallelConfig,
226    ) -> Self {
227        Self {
228            lr,
229            betas,
230            eps,
231            weight_decay,
232            state: ParallelOptimizerState::new(config),
233        }
234    }
235
236    /// Updates multiple parameters in parallel.
237    pub fn update_parallel(&self, updates: Vec<(String, &mut [f32], &[f32])>) -> Result<()> {
238        let _chunk_size = self.state.config.chunk_size;
239        let min_params = self.state.config.min_params_per_thread;
240
241        if updates.len() < min_params || !self.should_parallelize(&updates) {
242            // Use sequential processing for small workloads
243            return self.update_sequential(updates);
244        }
245
246        // Parallel processing using rayon
247        let results: Result<Vec<()>> = updates
248            .into_par_iter()
249            .with_min_len(1)
250            .map(|(param_id, param, grad)| self.update_single_parameter(param_id, param, grad))
251            .collect();
252
253        results.map(|_| ())
254    }
255
256    /// Updates parameters sequentially.
257    fn update_sequential(&self, updates: Vec<(String, &mut [f32], &[f32])>) -> Result<()> {
258        for (param_id, param, grad) in updates {
259            self.update_single_parameter(param_id, param, grad)?;
260        }
261        Ok(())
262    }
263
264    /// Updates a single parameter with parallel chunk processing.
265    fn update_single_parameter(
266        &self,
267        param_id: String,
268        param: &mut [f32],
269        grad: &[f32],
270    ) -> Result<()> {
271        if param.len() != grad.len() {
272            return Err(TrustformersError::tensor_op_error(
273                "Parameter and gradient size mismatch",
274                "update_single_parameter",
275            ));
276        }
277
278        let size = param.len();
279        let state_arc = self.state.get_or_create_state(param_id, size);
280        let chunk_size = self.state.config.chunk_size;
281
282        // Lock the parameter state
283        let mut param_state = state_arc.lock().expect("Parallel optimizer state lock poisoned");
284        param_state.step += 1;
285        param_state.last_update = std::time::Instant::now();
286
287        let step = param_state.step;
288        let (bias_correction1, bias_correction2) =
289            BiasCorrection::compute_adam_corrections(self.betas.0, self.betas.1, step);
290
291        // Determine if we should parallelize this parameter
292        let should_parallelize = size >= chunk_size * 2 && self.state.config.num_threads > 1;
293        if should_parallelize {
294            // Parallel chunk processing
295            let ParameterState {
296                ref mut momentum,
297                ref mut variance,
298                ..
299            } = *param_state;
300            self.update_parameter_parallel(
301                param,
302                grad,
303                momentum,
304                variance,
305                bias_correction1,
306                bias_correction2,
307                chunk_size,
308            );
309        } else {
310            // Sequential processing for small parameters
311            let ParameterState {
312                ref mut momentum,
313                ref mut variance,
314                ..
315            } = *param_state;
316            self.update_parameter_sequential(
317                param,
318                grad,
319                momentum,
320                variance,
321                bias_correction1,
322                bias_correction2,
323            );
324        }
325
326        Ok(())
327    }
328
329    /// Updates parameter using parallel chunk processing.
330    fn update_parameter_parallel(
331        &self,
332        param: &mut [f32],
333        grad: &[f32],
334        momentum: &mut [f32],
335        variance: &mut [f32],
336        bias_correction1: f32,
337        bias_correction2: f32,
338        chunk_size: usize,
339    ) {
340        // Use parallel iterators for chunk-based processing
341        param
342            .par_chunks_mut(chunk_size)
343            .zip(grad.par_chunks(chunk_size))
344            .zip(momentum.par_chunks_mut(chunk_size))
345            .zip(variance.par_chunks_mut(chunk_size))
346            .for_each(|(((p_chunk, g_chunk), m_chunk), v_chunk)| {
347                self.process_chunk(
348                    p_chunk,
349                    g_chunk,
350                    m_chunk,
351                    v_chunk,
352                    bias_correction1,
353                    bias_correction2,
354                );
355            });
356    }
357
358    /// Updates parameter sequentially.
359    fn update_parameter_sequential(
360        &self,
361        param: &mut [f32],
362        grad: &[f32],
363        momentum: &mut [f32],
364        variance: &mut [f32],
365        bias_correction1: f32,
366        bias_correction2: f32,
367    ) {
368        self.process_chunk(
369            param,
370            grad,
371            momentum,
372            variance,
373            bias_correction1,
374            bias_correction2,
375        );
376    }
377
378    /// Processes a chunk of parameters.
379    #[inline]
380    fn process_chunk(
381        &self,
382        param_chunk: &mut [f32],
383        grad_chunk: &[f32],
384        momentum_chunk: &mut [f32],
385        variance_chunk: &mut [f32],
386        bias_correction1: f32,
387        bias_correction2: f32,
388    ) {
389        // Use the minimum length to avoid index out of bounds
390        let len = param_chunk
391            .len()
392            .min(grad_chunk.len())
393            .min(momentum_chunk.len())
394            .min(variance_chunk.len());
395
396        for i in 0..len {
397            let grad_val = grad_chunk[i] + self.weight_decay * param_chunk[i];
398
399            // Update momentum and variance
400            ParameterUpdate::update_ema(&mut momentum_chunk[i], grad_val, self.betas.0);
401            ParameterUpdate::update_ema(&mut variance_chunk[i], grad_val * grad_val, self.betas.1);
402
403            // Apply bias-corrected update
404            let m_hat = momentum_chunk[i] / bias_correction1;
405            let v_hat = variance_chunk[i] / bias_correction2;
406
407            ParameterUpdate::adam_update(&mut param_chunk[i], self.lr, m_hat, v_hat, self.eps);
408        }
409    }
410
411    /// Determines if parallelization should be used based on workload.
412    fn should_parallelize(&self, updates: &[(String, &mut [f32], &[f32])]) -> bool {
413        let total_elements: usize = updates.iter().map(|(_, param, _)| param.len()).sum();
414        let num_threads = self.state.config.effective_num_threads();
415
416        total_elements >= self.state.config.min_params_per_thread * num_threads
417    }
418
419    /// Gets parallel performance statistics.
420    pub fn parallel_stats(&self) -> ParallelStats {
421        let memory_stats = self.state.memory_usage();
422        let num_threads = self.state.config.effective_num_threads();
423
424        ParallelStats {
425            num_threads,
426            memory_stats,
427            config: self.state.config.clone(),
428            current_step: self.state.get_step(),
429        }
430    }
431
432    /// Configures thread pool for optimal performance.
433    pub fn configure_thread_pool(&self) -> Result<()> {
434        let num_threads = self.state.config.effective_num_threads();
435
436        ThreadPoolBuilder::new().num_threads(num_threads).build_global().map_err(|e| {
437            TrustformersError::tensor_op_error(
438                &format!("Failed to configure thread pool: {}", e),
439                "configure_thread_pool",
440            )
441        })?;
442
443        Ok(())
444    }
445}
446
impl Optimizer for ParallelAdam {
    /// Applies one Adam step to a single parameter/gradient tensor pair.
    ///
    /// Only `F32` tensors are supported; other variants return a tensor-op error.
    /// NOTE(review): the state key is the parameter buffer's address, so the
    /// optimizer state is silently dropped and re-created if the tensor's
    /// storage is reallocated or moved between steps — confirm parameters are
    /// stably allocated for the lifetime of training.
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        match (parameter, grad) {
            (Tensor::F32(param), Tensor::F32(grad_arr)) => {
                // Identify the parameter by its data pointer.
                let param_id = format!("{:p}", param.as_ptr());
                self.update_single_parameter(
                    param_id,
                    param.as_slice_mut().unwrap(),
                    grad_arr.as_slice().unwrap(),
                )
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor types for ParallelAdam",
                "update",
            )),
        }
    }

    /// No-op: gradients are supplied by the caller, not stored by this optimizer.
    fn zero_grad(&mut self) {
        // No explicit gradient storage
    }

    /// Advances the shared global step counter.
    fn step(&mut self) {
        self.state.step();
    }

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32 {
        self.lr
    }

    /// Sets the learning rate for subsequent updates.
    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
    }
}
481
/// Performance statistics for parallel optimization.
///
/// Snapshot produced by [`ParallelAdam::parallel_stats`]; values reflect the
/// moment of the call and are not updated afterwards.
#[derive(Debug, Clone)]
pub struct ParallelStats {
    /// Number of worker threads (already resolved from auto-detect).
    pub num_threads: usize,
    /// Memory usage statistics for the optimizer state.
    pub memory_stats: StateMemoryStats,
    /// Parallel configuration in effect when the snapshot was taken.
    pub config: ParallelConfig,
    /// Current optimization step (global counter).
    pub current_step: usize,
}
494
495impl ParallelStats {
496    /// Calculates theoretical speedup based on workload.
497    pub fn theoretical_speedup(&self, _sequential_time_ms: f64) -> f64 {
498        // Simple Amdahl's law approximation
499        let parallel_fraction = 0.95; // Assume 95% of work can be parallelized
500        let num_threads = self.num_threads as f64;
501
502        1.0 / ((1.0 - parallel_fraction) + (parallel_fraction / num_threads))
503    }
504
505    /// Suggests optimization improvements.
506    pub fn optimization_suggestions(&self) -> Vec<String> {
507        let mut suggestions = Vec::new();
508
509        if self.num_threads == 1 {
510            suggestions.push(
511                "Consider increasing number of threads for better parallelization".to_string(),
512            );
513        }
514
515        if self.num_threads > num_cpus::get() {
516            suggestions.push("Number of threads exceeds CPU cores; consider reducing".to_string());
517        }
518
519        if self.config.chunk_size < 256 {
520            suggestions
521                .push("Small chunk size may cause overhead; consider increasing".to_string());
522        }
523
524        if self.config.chunk_size > 8192 {
525            suggestions.push("Large chunk size may reduce parallelization efficiency".to_string());
526        }
527
528        if !self.config.enable_work_stealing {
529            suggestions.push("Enable work stealing for better load balancing".to_string());
530        }
531
532        if suggestions.is_empty() {
533            suggestions.push("Parallel configuration appears optimal".to_string());
534        }
535
536        suggestions
537    }
538}
539
/// Batch parameter update interface for better parallelization.
///
/// Lets callers hand an optimizer many (parameter, gradient) pairs at once so
/// the implementation can distribute the whole batch across threads.
pub trait BatchUpdate {
    /// Updates multiple parameters in a single batch operation.
    ///
    /// Implementations should fail fast: an unsupported pair aborts the batch.
    fn update_batch(&mut self, batch: Vec<(&mut Tensor, &Tensor)>) -> Result<()>;
}
545
impl BatchUpdate for ParallelAdam {
    /// Converts the tensor batch into (id, param slice, grad slice) triples
    /// and dispatches them through `update_parallel`.
    ///
    /// # Errors
    /// Returns a tensor-op error on the first non-`F32` pair; in that case no
    /// parameter in the batch has been updated yet (conversion happens before
    /// any update runs).
    fn update_batch(&mut self, batch: Vec<(&mut Tensor, &Tensor)>) -> Result<()> {
        let mut updates = Vec::new();

        for (param, grad) in batch {
            match (param, grad) {
                (Tensor::F32(p), Tensor::F32(g)) => {
                    // Pointer-derived id: same keying scheme as `Optimizer::update`.
                    let param_id = format!("{:p}", p.as_ptr());
                    updates.push((param_id, p.as_slice_mut().unwrap(), g.as_slice().unwrap()));
                },
                _ => {
                    return Err(TrustformersError::tensor_op_error(
                        "Unsupported tensor types",
                        "update_batch",
                    ))
                },
            }
        }

        self.update_parallel(updates)
    }
}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config auto-detects threads; presets pin explicit counts.
    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::default();
        assert_eq!(config.num_threads, 0); // Auto-detect
        assert!(config.enable_work_stealing);

        let cpu_config = ParallelConfig::cpu_optimized();
        assert_eq!(cpu_config.num_threads, num_cpus::get());

        // 0 must resolve to the actual core count, never 0 threads.
        let effective_threads = config.effective_num_threads();
        assert!(effective_threads > 0);
        assert_eq!(effective_threads, num_cpus::get());
    }

    /// Step counter increments and parameter states are sized on creation.
    #[test]
    fn test_parallel_optimizer_state() {
        let config = ParallelConfig::default();
        let state = ParallelOptimizerState::new(config);

        assert_eq!(state.get_step(), 0);
        state.step();
        assert_eq!(state.get_step(), 1);

        let param_state = state.get_or_create_state("test_param".to_string(), 100);
        let locked_state = param_state.lock().expect("Parallel optimizer state lock poisoned");
        assert_eq!(locked_state.momentum.len(), 100);
        assert_eq!(locked_state.variance.len(), 100);
    }

    /// Constructor wires hyperparameters through and starts at step 0.
    #[test]
    fn test_parallel_adam() {
        let optimizer = ParallelAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.betas, (0.9, 0.999));

        let stats = optimizer.parallel_stats();
        assert!(stats.num_threads > 0);
        assert_eq!(stats.current_step, 0);
    }

    /// The element-count heuristic: 100 elements < 1000*4 threshold (no),
    /// 5000 elements >= 1000*4 threshold with 4 threads... see assertions.
    #[test]
    fn test_should_parallelize() {
        let config = ParallelConfig {
            min_params_per_thread: 1000,
            num_threads: 4,
            ..Default::default()
        };
        let optimizer = ParallelAdam::with_config(1e-3, (0.9, 0.999), 1e-8, 0.01, config);

        // Small workload - should not parallelize
        let mut small_params = [0.0; 100];
        let small_grads = [0.0; 100];
        let small_updates = vec![(
            "param1".to_string(),
            &mut small_params[..],
            &small_grads[..],
        )];
        assert!(!optimizer.should_parallelize(&small_updates));

        // Large workload - should parallelize (5000 >= 1000 * 4 threads)
        let mut large_params = [0.0; 5000];
        let large_grads = [0.0; 5000];
        let large_updates = vec![(
            "param1".to_string(),
            &mut large_params[..],
            &large_grads[..],
        )];
        assert!(optimizer.should_parallelize(&large_updates));
    }

    /// Amdahl estimate must land strictly between 1x and the thread count.
    #[test]
    fn test_parallel_stats() {
        let optimizer = ParallelAdam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        let stats = optimizer.parallel_stats();

        let speedup = stats.theoretical_speedup(1000.0);
        assert!(speedup > 1.0);
        assert!(speedup <= stats.num_threads as f64);

        let suggestions = stats.optimization_suggestions();
        assert!(!suggestions.is_empty());
    }

    /// Memory stats aggregate element counts across all parameter states.
    #[test]
    fn test_memory_usage() {
        let config = ParallelConfig::default();
        let state = ParallelOptimizerState::new(config);

        // Create some parameter states
        state.get_or_create_state("param1".to_string(), 1000);
        state.get_or_create_state("param2".to_string(), 2000);

        let memory_stats = state.memory_usage();
        assert_eq!(memory_stats.num_parameters, 2);
        assert_eq!(memory_stats.momentum_elements, 3000);
        assert_eq!(memory_stats.variance_elements, 3000);
    }
}