// trustformers_optim/async_optim.rs
//! # Asynchronous Optimization Methods
//!
//! This module implements asynchronous optimization algorithms for distributed
//! training where workers can update parameters without strict synchronization.
//!
//! ## Available Methods
//!
//! - **Async SGD**: Asynchronous stochastic gradient descent
//! - **Hogwild!**: Lock-free asynchronous SGD for sparse features
//! - **Delayed Gradient**: Methods that handle stale gradients
//! - **Elastic Averaging SGD**: Combines local and global parameter averaging

use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use anyhow::Result;
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};

use trustformers_core::tensor::Tensor;

22/// Configuration for asynchronous SGD.
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct AsyncSGDConfig {
25    /// Learning rate
26    pub learning_rate: f32,
27    /// Momentum coefficient
28    pub momentum: f32,
29    /// Weight decay
30    pub weight_decay: f32,
31    /// Maximum allowed staleness for gradient updates
32    pub max_staleness: usize,
33    /// Staleness adaptive factor
34    pub staleness_factor: f32,
35}
36
37impl Default for AsyncSGDConfig {
38    fn default() -> Self {
39        Self {
40            learning_rate: 1e-3,
41            momentum: 0.9,
42            weight_decay: 0.0,
43            max_staleness: 10,
44            staleness_factor: 0.9,
45        }
46    }
47}
48
49/// Configuration for Hogwild! optimizer.
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct HogwildConfig {
52    /// Learning rate
53    pub learning_rate: f32,
54    /// Sparse update ratio (fraction of parameters updated per step)
55    pub sparse_ratio: f32,
56    /// Maximum number of concurrent workers
57    pub max_workers: usize,
58}
59
60impl Default for HogwildConfig {
61    fn default() -> Self {
62        Self {
63            learning_rate: 1e-3,
64            sparse_ratio: 0.1,
65            max_workers: 4,
66        }
67    }
68}
69
70/// Configuration for delayed gradient methods.
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct DelayedGradientConfig {
73    /// Base learning rate
74    pub learning_rate: f32,
75    /// Maximum gradient delay (in steps)
76    pub max_delay: usize,
77    /// Delay compensation method
78    pub compensation_method: DelayCompensationMethod,
79    /// Compensation factor
80    pub compensation_factor: f32,
81}
82
83impl Default for DelayedGradientConfig {
84    fn default() -> Self {
85        Self {
86            learning_rate: 1e-3,
87            max_delay: 20,
88            compensation_method: DelayCompensationMethod::LinearDecay,
89            compensation_factor: 0.5,
90        }
91    }
92}
93
94/// Methods for compensating gradient delays.
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub enum DelayCompensationMethod {
97    /// No compensation
98    None,
99    /// Linear decay based on delay
100    LinearDecay,
101    /// Exponential decay based on delay
102    ExponentialDecay,
103    /// Adaptive compensation
104    Adaptive,
105}
106
107/// Configuration for Elastic Averaging SGD.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ElasticAveragingConfig {
110    /// Learning rate
111    pub learning_rate: f32,
112    /// Elastic force coefficient
113    pub alpha: f32,
114    /// Communication period (steps between synchronization)
115    pub tau: usize,
116    /// Beta parameter for moving average
117    pub beta: f32,
118}
119
120impl Default for ElasticAveragingConfig {
121    fn default() -> Self {
122        Self {
123            learning_rate: 1e-3,
124            alpha: 0.6,
125            tau: 10,
126            beta: 0.9,
127        }
128    }
129}
130
131/// Shared parameter server for asynchronous optimization.
132pub struct ParameterServer {
133    /// Global parameters
134    parameters: Arc<RwLock<Vec<Tensor>>>,
135    /// Global step counter
136    global_step: AtomicUsize,
137    /// Parameter version counters
138    version_counters: Arc<Mutex<Vec<usize>>>,
139    /// Worker update timestamps
140    worker_timestamps: Arc<Mutex<HashMap<usize, Instant>>>,
141}
142
143impl ParameterServer {
144    /// Create a new parameter server.
145    pub fn new(initial_parameters: Vec<Tensor>) -> Self {
146        let param_count = initial_parameters.len();
147        Self {
148            parameters: Arc::new(RwLock::new(initial_parameters)),
149            global_step: AtomicUsize::new(0),
150            version_counters: Arc::new(Mutex::new(vec![0; param_count])),
151            worker_timestamps: Arc::new(Mutex::new(HashMap::new())),
152        }
153    }
154
155    /// Get current parameters for a worker.
156    pub fn get_parameters(&self, worker_id: usize) -> Result<(Vec<Tensor>, Vec<usize>)> {
157        let params = self.parameters.read().clone();
158        let versions = self.version_counters.lock().clone();
159
160        // Update worker timestamp
161        let mut timestamps = self.worker_timestamps.lock();
162        timestamps.insert(worker_id, Instant::now());
163
164        Ok((params, versions))
165    }
166
167    /// Update parameters with gradients from a worker.
168    pub fn update_parameters(
169        &self,
170        worker_id: usize,
171        gradients: Vec<Tensor>,
172        param_versions: Vec<usize>,
173        learning_rate: f32,
174    ) -> Result<()> {
175        let _current_step = self.global_step.load(Ordering::SeqCst);
176
177        // Check staleness
178        let staleness = self.compute_staleness(worker_id, &param_versions)?;
179        if staleness > 10 {
180            // Skip very stale updates
181            return Ok(());
182        }
183
184        // Apply staleness compensation
185        let compensated_lr = learning_rate * (1.0 / (1.0 + staleness as f32 * 0.1));
186
187        // Update parameters
188        {
189            let mut params = self.parameters.write();
190            let mut versions = self.version_counters.lock();
191
192            for (i, gradient) in gradients.iter().enumerate() {
193                if i < params.len() {
194                    let update = gradient.mul_scalar(compensated_lr)?;
195                    params[i] = params[i].sub(&update)?;
196                    versions[i] += 1;
197                }
198            }
199        }
200
201        self.global_step.fetch_add(1, Ordering::SeqCst);
202        Ok(())
203    }
204
205    fn compute_staleness(&self, _worker_id: usize, param_versions: &[usize]) -> Result<usize> {
206        let current_versions = self.version_counters.lock();
207        let max_staleness = param_versions
208            .iter()
209            .zip(current_versions.iter())
210            .map(|(old, new)| new.saturating_sub(*old))
211            .max()
212            .unwrap_or(0);
213        Ok(max_staleness)
214    }
215
216    /// Get current global step.
217    pub fn get_global_step(&self) -> usize {
218        self.global_step.load(Ordering::SeqCst)
219    }
220}
221
222/// Asynchronous SGD optimizer.
223pub struct AsyncSGD {
224    config: AsyncSGDConfig,
225    worker_id: usize,
226    parameter_server: Arc<ParameterServer>,
227    momentum_buffers: Vec<Tensor>,
228    local_parameters: Vec<Tensor>,
229    param_versions: Vec<usize>,
230    last_sync_step: usize,
231}
232
233impl AsyncSGD {
234    /// Create a new async SGD optimizer.
235    pub fn new(
236        config: AsyncSGDConfig,
237        worker_id: usize,
238        parameter_server: Arc<ParameterServer>,
239    ) -> Result<Self> {
240        let (params, versions) = parameter_server.get_parameters(worker_id)?;
241        let param_count = params.len();
242
243        Ok(Self {
244            config,
245            worker_id,
246            parameter_server,
247            momentum_buffers: (0..param_count)
248                .map(|i| Tensor::zeros(&params[i].shape()).map_err(anyhow::Error::from))
249                .collect::<Result<Vec<_>>>()?,
250            local_parameters: params,
251            param_versions: versions,
252            last_sync_step: 0,
253        })
254    }
255
256    /// Perform an optimization step.
257    pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
258        // Check if we need to sync with parameter server
259        let current_step = self.parameter_server.get_global_step();
260        let staleness = current_step - self.last_sync_step;
261
262        if staleness > self.config.max_staleness {
263            self.sync_with_server()?;
264        }
265
266        // Apply momentum and update local parameters
267        for (i, gradient) in gradients.iter().enumerate() {
268            if i < self.local_parameters.len() {
269                // Apply weight decay
270                let effective_grad = if self.config.weight_decay > 0.0 {
271                    gradient.add(&self.local_parameters[i].mul_scalar(self.config.weight_decay)?)?
272                } else {
273                    gradient.clone()
274                };
275
276                // Update momentum
277                self.momentum_buffers[i] = self.momentum_buffers[i]
278                    .mul_scalar(self.config.momentum)?
279                    .add(&effective_grad)?;
280
281                // Apply staleness compensation
282                let staleness_factor = self.config.staleness_factor.powi(staleness as i32);
283                let compensated_lr = self.config.learning_rate * staleness_factor;
284
285                // Update local parameters
286                let update = self.momentum_buffers[i].mul_scalar(compensated_lr)?;
287                self.local_parameters[i] = self.local_parameters[i].sub(&update)?;
288            }
289        }
290
291        // Send updates to parameter server periodically
292        if current_step % 5 == 0 {
293            self.push_to_server(gradients)?;
294        }
295
296        Ok(())
297    }
298
299    fn sync_with_server(&mut self) -> Result<()> {
300        let (params, versions) = self.parameter_server.get_parameters(self.worker_id)?;
301        self.local_parameters = params;
302        self.param_versions = versions;
303        self.last_sync_step = self.parameter_server.get_global_step();
304        Ok(())
305    }
306
307    fn push_to_server(&self, gradients: &[Tensor]) -> Result<()> {
308        self.parameter_server.update_parameters(
309            self.worker_id,
310            gradients.to_vec(),
311            self.param_versions.clone(),
312            self.config.learning_rate,
313        )
314    }
315
316    /// Get current local parameters.
317    pub fn get_parameters(&self) -> &[Tensor] {
318        &self.local_parameters
319    }
320}
321
322/// Hogwild! optimizer for sparse features.
323pub struct Hogwild {
324    config: HogwildConfig,
325    #[allow(dead_code)]
326    worker_id: usize,
327    shared_parameters: Arc<RwLock<Vec<Tensor>>>,
328    local_step: usize,
329}
330
331impl Hogwild {
332    /// Create a new Hogwild! optimizer.
333    pub fn new(
334        config: HogwildConfig,
335        worker_id: usize,
336        shared_parameters: Arc<RwLock<Vec<Tensor>>>,
337    ) -> Self {
338        Self {
339            config,
340            worker_id,
341            shared_parameters,
342            local_step: 0,
343        }
344    }
345
346    /// Perform sparse parameter update.
347    pub fn sparse_step(&mut self, sparse_gradients: &[(usize, Tensor)]) -> Result<()> {
348        // Lock-free updates for sparse gradients
349        // In practice, this would use atomic operations for true lock-free behavior
350
351        for &(param_idx, ref gradient) in sparse_gradients {
352            {
353                let params = self.shared_parameters.read();
354                if param_idx >= params.len() {
355                    continue;
356                }
357            } // Release read lock
358
359            // This is a simplified version - real Hogwild! uses lock-free atomic updates
360            let mut params_write = self.shared_parameters.write();
361            let update = gradient.mul_scalar(self.config.learning_rate)?;
362            params_write[param_idx] = params_write[param_idx].sub(&update)?;
363        }
364
365        self.local_step += 1;
366        Ok(())
367    }
368
369    /// Generate sparse gradient indices based on sparse ratio.
370    pub fn select_sparse_indices(&self, total_params: usize) -> Vec<usize> {
371        use scirs2_core::random::*; // SciRS2 Integration Policy
372
373        let num_sparse = (total_params as f32 * self.config.sparse_ratio) as usize;
374        let mut indices: Vec<usize> = (0..total_params).collect();
375        let mut rng = thread_rng();
376        indices.shuffle(rng.rng_mut());
377        indices.truncate(num_sparse);
378        indices
379    }
380}
381
382/// Delayed gradient optimizer.
383pub struct DelayedGradient {
384    config: DelayedGradientConfig,
385    parameters: Vec<Tensor>,
386    gradient_buffer: Vec<(Tensor, usize, Instant)>, // (gradient, delay, timestamp)
387    current_step: usize,
388}
389
390impl DelayedGradient {
391    /// Create a new delayed gradient optimizer.
392    pub fn new(config: DelayedGradientConfig, initial_parameters: Vec<Tensor>) -> Self {
393        Self {
394            config,
395            parameters: initial_parameters,
396            gradient_buffer: Vec::new(),
397            current_step: 0,
398        }
399    }
400
401    /// Add a delayed gradient to the buffer.
402    pub fn add_delayed_gradient(&mut self, gradient: Tensor, delay: usize) {
403        self.gradient_buffer.push((gradient, delay, Instant::now()));
404    }
405
406    /// Process delayed gradients and update parameters.
407    pub fn step(&mut self) -> Result<()> {
408        self.current_step += 1;
409
410        // Process gradients that are ready
411        let mut i = 0;
412        while i < self.gradient_buffer.len() {
413            let (ref gradient, delay, timestamp) = &self.gradient_buffer[i];
414            let age = timestamp.elapsed();
415
416            if age >= Duration::from_millis((*delay as u64) * 10) {
417                // Apply delay compensation
418                let compensation = self.compute_delay_compensation(*delay)?;
419                let compensated_lr = self.config.learning_rate * compensation;
420
421                // Update parameters
422                for (j, param) in self.parameters.iter_mut().enumerate() {
423                    if j < 1 {
424                        // Assuming single parameter for simplicity
425                        let update = gradient.mul_scalar(compensated_lr)?;
426                        *param = param.sub(&update)?;
427                    }
428                }
429
430                self.gradient_buffer.remove(i);
431            } else {
432                i += 1;
433            }
434        }
435
436        Ok(())
437    }
438
439    fn compute_delay_compensation(&self, delay: usize) -> Result<f32> {
440        if delay > self.config.max_delay {
441            return Ok(0.0); // Discard very old gradients
442        }
443
444        let delay_ratio = delay as f32 / self.config.max_delay as f32;
445
446        let compensation = match self.config.compensation_method {
447            DelayCompensationMethod::None => 1.0,
448            DelayCompensationMethod::LinearDecay => {
449                1.0 - delay_ratio * self.config.compensation_factor
450            },
451            DelayCompensationMethod::ExponentialDecay => {
452                (-delay_ratio * self.config.compensation_factor).exp()
453            },
454            DelayCompensationMethod::Adaptive => {
455                // Simple adaptive scheme
456                1.0 / (1.0 + delay_ratio * self.config.compensation_factor)
457            },
458        };
459
460        Ok(compensation.max(0.1)) // Minimum 10% of original learning rate
461    }
462
463    /// Get current parameters.
464    pub fn get_parameters(&self) -> &[Tensor] {
465        &self.parameters
466    }
467}
468
469/// Elastic Averaging SGD optimizer.
470pub struct ElasticAveraging {
471    config: ElasticAveragingConfig,
472    #[allow(dead_code)]
473    worker_id: usize,
474    local_parameters: Vec<Tensor>,
475    global_parameters: Arc<RwLock<Vec<Tensor>>>,
476    elastic_force: Vec<Tensor>,
477    local_step: usize,
478    last_communication: usize,
479}
480
481impl ElasticAveraging {
482    /// Create a new Elastic Averaging SGD optimizer.
483    pub fn new(
484        config: ElasticAveragingConfig,
485        worker_id: usize,
486        global_parameters: Arc<RwLock<Vec<Tensor>>>,
487    ) -> Result<Self> {
488        let global_params = global_parameters.read().clone();
489        let param_count = global_params.len();
490
491        Ok(Self {
492            config,
493            worker_id,
494            local_parameters: global_params.clone(),
495            global_parameters,
496            elastic_force: (0..param_count)
497                .map(|i| Tensor::zeros(&global_params[i].shape()).map_err(anyhow::Error::from))
498                .collect::<Result<Vec<_>>>()?,
499            local_step: 0,
500            last_communication: 0,
501        })
502    }
503
504    /// Perform optimization step with elastic averaging.
505    pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
506        // Update local parameters with gradients
507        for (i, gradient) in gradients.iter().enumerate() {
508            if i < self.local_parameters.len() {
509                let update = gradient.mul_scalar(self.config.learning_rate)?;
510                self.local_parameters[i] = self.local_parameters[i].sub(&update)?;
511            }
512        }
513
514        // Apply elastic force
515        let global_params = self.global_parameters.read();
516        for i in 0..self.local_parameters.len() {
517            let diff = self.local_parameters[i].sub(&global_params[i])?;
518            self.elastic_force[i] = diff.mul_scalar(self.config.alpha)?;
519            let elastic_update = self.elastic_force[i].mul_scalar(self.config.learning_rate)?;
520            self.local_parameters[i] = self.local_parameters[i].sub(&elastic_update)?;
521        }
522        drop(global_params);
523
524        self.local_step += 1;
525
526        // Communicate with global parameters periodically
527        if self.local_step - self.last_communication >= self.config.tau {
528            self.communicate_with_global()?;
529            self.last_communication = self.local_step;
530        }
531
532        Ok(())
533    }
534
535    fn communicate_with_global(&mut self) -> Result<()> {
536        let mut global_params = self.global_parameters.write();
537
538        // Update global parameters with moving average
539        for i in 0..global_params.len() {
540            let local_contrib = self.local_parameters[i].mul_scalar(1.0 - self.config.beta)?;
541            let global_contrib = global_params[i].mul_scalar(self.config.beta)?;
542            global_params[i] = local_contrib.add(&global_contrib)?;
543        }
544
545        // Update local parameters from global
546        self.local_parameters = global_params.clone();
547
548        Ok(())
549    }
550
551    /// Get current local parameters.
552    pub fn get_parameters(&self) -> &[Tensor] {
553        &self.local_parameters
554    }
555}
556
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_async_sgd_config() {
        // Defaults must match the documented values.
        let cfg = AsyncSGDConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.momentum, 0.9);
        assert_eq!(cfg.max_staleness, 10);
    }

    #[test]
    fn test_hogwild_config() {
        let cfg = HogwildConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.sparse_ratio, 0.1);
        assert_eq!(cfg.max_workers, 4);
    }

    #[test]
    fn test_delayed_gradient_config() {
        let cfg = DelayedGradientConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.max_delay, 20);
        // LinearDecay is the default compensation strategy.
        assert!(matches!(
            cfg.compensation_method,
            DelayCompensationMethod::LinearDecay
        ));
    }

    #[test]
    fn test_parameter_server_creation() {
        // A fresh server starts at global step zero.
        let server = ParameterServer::new(vec![Tensor::zeros(&[10]).unwrap()]);
        assert_eq!(server.get_global_step(), 0);
    }

    #[test]
    fn test_elastic_averaging_config() {
        let cfg = ElasticAveragingConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.alpha, 0.6);
        assert_eq!(cfg.tau, 10);
        assert_eq!(cfg.beta, 0.9);
    }
}