1use anyhow::Result;
14use parking_lot::{Mutex, RwLock};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::sync::Arc;
19use std::time::{Duration, Instant};
20use trustformers_core::tensor::Tensor;
21
/// Configuration for asynchronous SGD workers backed by a [`ParameterServer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AsyncSGDConfig {
    /// Base step size applied to momentum-adjusted gradients.
    pub learning_rate: f32,
    /// Momentum coefficient for the per-parameter velocity buffers.
    pub momentum: f32,
    /// L2 penalty folded into the gradient; 0.0 disables weight decay.
    pub weight_decay: f32,
    /// Number of global steps a worker may lag before forcing a resync.
    pub max_staleness: usize,
    /// Base of the exponential learning-rate decay applied per unit of staleness.
    pub staleness_factor: f32,
}
36
37impl Default for AsyncSGDConfig {
38 fn default() -> Self {
39 Self {
40 learning_rate: 1e-3,
41 momentum: 0.9,
42 weight_decay: 0.0,
43 max_staleness: 10,
44 staleness_factor: 0.9,
45 }
46 }
47}
48
/// Configuration for lock-minimal Hogwild-style sparse updates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HogwildConfig {
    /// Step size for each sparse SGD update.
    pub learning_rate: f32,
    /// Fraction of all parameters randomly selected for update each round.
    pub sparse_ratio: f32,
    /// Worker-count limit.
    // NOTE(review): not referenced by any code visible in this file —
    // presumably consumed by whatever spawns the workers; confirm.
    pub max_workers: usize,
}
59
60impl Default for HogwildConfig {
61 fn default() -> Self {
62 Self {
63 learning_rate: 1e-3,
64 sparse_ratio: 0.1,
65 max_workers: 4,
66 }
67 }
68}
69
/// Configuration for the delayed-gradient optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DelayedGradientConfig {
    /// Step size before delay compensation is applied.
    pub learning_rate: f32,
    /// Delays larger than this zero out the update (compensation returns 0).
    pub max_delay: usize,
    /// Strategy used to shrink the learning rate as delay grows.
    pub compensation_method: DelayCompensationMethod,
    /// Strength of the chosen compensation method (larger = stronger decay).
    pub compensation_factor: f32,
}
82
83impl Default for DelayedGradientConfig {
84 fn default() -> Self {
85 Self {
86 learning_rate: 1e-3,
87 max_delay: 20,
88 compensation_method: DelayCompensationMethod::LinearDecay,
89 compensation_factor: 0.5,
90 }
91 }
92}
93
/// How the learning rate is attenuated as a function of gradient delay.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DelayCompensationMethod {
    /// No attenuation; delayed gradients are applied at full strength.
    None,
    /// Multiplier decreases linearly with delay: `1 - ratio * factor`.
    LinearDecay,
    /// Multiplier decays exponentially: `exp(-ratio * factor)`.
    ExponentialDecay,
    /// Hyperbolic attenuation: `1 / (1 + ratio * factor)`.
    Adaptive,
}
106
/// Configuration for Elastic Averaging SGD (EASGD-style) workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ElasticAveragingConfig {
    /// Step size for local SGD and for the elastic pull toward the center.
    pub learning_rate: f32,
    /// Elastic-force coefficient: strength of the pull toward the global parameters.
    pub alpha: f32,
    /// Communication period: local steps between syncs with the global state.
    pub tau: usize,
    /// Weight kept by the existing global parameters when local ones are averaged in.
    pub beta: f32,
}
119
120impl Default for ElasticAveragingConfig {
121 fn default() -> Self {
122 Self {
123 learning_rate: 1e-3,
124 alpha: 0.6,
125 tau: 10,
126 beta: 0.9,
127 }
128 }
129}
130
/// Central parameter store shared by asynchronous workers.
pub struct ParameterServer {
    /// Current global parameters, guarded for concurrent read/write access.
    parameters: Arc<RwLock<Vec<Tensor>>>,
    /// Number of accepted updates so far.
    global_step: AtomicUsize,
    /// Per-parameter update counters used to measure worker staleness.
    version_counters: Arc<Mutex<Vec<usize>>>,
    /// Wall-clock time each worker last pulled parameters.
    worker_timestamps: Arc<Mutex<HashMap<usize, Instant>>>,
}
142
143impl ParameterServer {
144 pub fn new(initial_parameters: Vec<Tensor>) -> Self {
146 let param_count = initial_parameters.len();
147 Self {
148 parameters: Arc::new(RwLock::new(initial_parameters)),
149 global_step: AtomicUsize::new(0),
150 version_counters: Arc::new(Mutex::new(vec![0; param_count])),
151 worker_timestamps: Arc::new(Mutex::new(HashMap::new())),
152 }
153 }
154
155 pub fn get_parameters(&self, worker_id: usize) -> Result<(Vec<Tensor>, Vec<usize>)> {
157 let params = self.parameters.read().clone();
158 let versions = self.version_counters.lock().clone();
159
160 let mut timestamps = self.worker_timestamps.lock();
162 timestamps.insert(worker_id, Instant::now());
163
164 Ok((params, versions))
165 }
166
167 pub fn update_parameters(
169 &self,
170 worker_id: usize,
171 gradients: Vec<Tensor>,
172 param_versions: Vec<usize>,
173 learning_rate: f32,
174 ) -> Result<()> {
175 let _current_step = self.global_step.load(Ordering::SeqCst);
176
177 let staleness = self.compute_staleness(worker_id, ¶m_versions)?;
179 if staleness > 10 {
180 return Ok(());
182 }
183
184 let compensated_lr = learning_rate * (1.0 / (1.0 + staleness as f32 * 0.1));
186
187 {
189 let mut params = self.parameters.write();
190 let mut versions = self.version_counters.lock();
191
192 for (i, gradient) in gradients.iter().enumerate() {
193 if i < params.len() {
194 let update = gradient.mul_scalar(compensated_lr)?;
195 params[i] = params[i].sub(&update)?;
196 versions[i] += 1;
197 }
198 }
199 }
200
201 self.global_step.fetch_add(1, Ordering::SeqCst);
202 Ok(())
203 }
204
205 fn compute_staleness(&self, _worker_id: usize, param_versions: &[usize]) -> Result<usize> {
206 let current_versions = self.version_counters.lock();
207 let max_staleness = param_versions
208 .iter()
209 .zip(current_versions.iter())
210 .map(|(old, new)| new.saturating_sub(*old))
211 .max()
212 .unwrap_or(0);
213 Ok(max_staleness)
214 }
215
216 pub fn get_global_step(&self) -> usize {
218 self.global_step.load(Ordering::SeqCst)
219 }
220}
221
/// Worker-side asynchronous SGD-with-momentum optimizer.
///
/// Keeps a local parameter snapshot, periodically resynchronizes with the
/// shared [`ParameterServer`], and damps updates according to staleness.
pub struct AsyncSGD {
    /// Hyperparameters for this worker.
    config: AsyncSGDConfig,
    /// Identifier reported to the parameter server.
    worker_id: usize,
    /// Shared server holding the authoritative parameters.
    parameter_server: Arc<ParameterServer>,
    /// Per-parameter momentum (velocity) buffers.
    momentum_buffers: Vec<Tensor>,
    /// Local snapshot of the parameters this worker trains against.
    local_parameters: Vec<Tensor>,
    /// Version counters captured at the last pull from the server.
    param_versions: Vec<usize>,
    /// Global step recorded at the last full sync.
    last_sync_step: usize,
}
232
233impl AsyncSGD {
234 pub fn new(
236 config: AsyncSGDConfig,
237 worker_id: usize,
238 parameter_server: Arc<ParameterServer>,
239 ) -> Result<Self> {
240 let (params, versions) = parameter_server.get_parameters(worker_id)?;
241 let param_count = params.len();
242
243 Ok(Self {
244 config,
245 worker_id,
246 parameter_server,
247 momentum_buffers: (0..param_count)
248 .map(|i| Tensor::zeros(¶ms[i].shape()).map_err(anyhow::Error::from))
249 .collect::<Result<Vec<_>>>()?,
250 local_parameters: params,
251 param_versions: versions,
252 last_sync_step: 0,
253 })
254 }
255
256 pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
258 let current_step = self.parameter_server.get_global_step();
260 let staleness = current_step - self.last_sync_step;
261
262 if staleness > self.config.max_staleness {
263 self.sync_with_server()?;
264 }
265
266 for (i, gradient) in gradients.iter().enumerate() {
268 if i < self.local_parameters.len() {
269 let effective_grad = if self.config.weight_decay > 0.0 {
271 gradient.add(&self.local_parameters[i].mul_scalar(self.config.weight_decay)?)?
272 } else {
273 gradient.clone()
274 };
275
276 self.momentum_buffers[i] = self.momentum_buffers[i]
278 .mul_scalar(self.config.momentum)?
279 .add(&effective_grad)?;
280
281 let staleness_factor = self.config.staleness_factor.powi(staleness as i32);
283 let compensated_lr = self.config.learning_rate * staleness_factor;
284
285 let update = self.momentum_buffers[i].mul_scalar(compensated_lr)?;
287 self.local_parameters[i] = self.local_parameters[i].sub(&update)?;
288 }
289 }
290
291 if current_step % 5 == 0 {
293 self.push_to_server(gradients)?;
294 }
295
296 Ok(())
297 }
298
299 fn sync_with_server(&mut self) -> Result<()> {
300 let (params, versions) = self.parameter_server.get_parameters(self.worker_id)?;
301 self.local_parameters = params;
302 self.param_versions = versions;
303 self.last_sync_step = self.parameter_server.get_global_step();
304 Ok(())
305 }
306
307 fn push_to_server(&self, gradients: &[Tensor]) -> Result<()> {
308 self.parameter_server.update_parameters(
309 self.worker_id,
310 gradients.to_vec(),
311 self.param_versions.clone(),
312 self.config.learning_rate,
313 )
314 }
315
316 pub fn get_parameters(&self) -> &[Tensor] {
318 &self.local_parameters
319 }
320}
321
/// Hogwild-style worker applying sparse updates straight to shared parameters.
pub struct Hogwild {
    /// Hyperparameters for this worker.
    config: HogwildConfig,
    #[allow(dead_code)]
    worker_id: usize,
    /// Parameters shared (and mutated) by all Hogwild workers.
    shared_parameters: Arc<RwLock<Vec<Tensor>>>,
    /// Number of `sparse_step` calls made by this worker.
    local_step: usize,
}
330
331impl Hogwild {
332 pub fn new(
334 config: HogwildConfig,
335 worker_id: usize,
336 shared_parameters: Arc<RwLock<Vec<Tensor>>>,
337 ) -> Self {
338 Self {
339 config,
340 worker_id,
341 shared_parameters,
342 local_step: 0,
343 }
344 }
345
346 pub fn sparse_step(&mut self, sparse_gradients: &[(usize, Tensor)]) -> Result<()> {
348 for &(param_idx, ref gradient) in sparse_gradients {
352 {
353 let params = self.shared_parameters.read();
354 if param_idx >= params.len() {
355 continue;
356 }
357 } let mut params_write = self.shared_parameters.write();
361 let update = gradient.mul_scalar(self.config.learning_rate)?;
362 params_write[param_idx] = params_write[param_idx].sub(&update)?;
363 }
364
365 self.local_step += 1;
366 Ok(())
367 }
368
369 pub fn select_sparse_indices(&self, total_params: usize) -> Vec<usize> {
371 use scirs2_core::random::*; let num_sparse = (total_params as f32 * self.config.sparse_ratio) as usize;
374 let mut indices: Vec<usize> = (0..total_params).collect();
375 let mut rng = thread_rng();
376 indices.shuffle(rng.rng_mut());
377 indices.truncate(num_sparse);
378 indices
379 }
380}
381
/// Optimizer that buffers gradients and applies them only after a simulated
/// transmission delay, scaling each by a delay-compensation factor.
pub struct DelayedGradient {
    /// Hyperparameters (learning rate, delay cap, compensation scheme).
    config: DelayedGradientConfig,
    /// Model parameters updated in place by `step`.
    parameters: Vec<Tensor>,
    /// Pending gradients: (gradient, logical delay, time it was enqueued).
    gradient_buffer: Vec<(Tensor, usize, Instant)>,
    /// Count of `step` calls made so far.
    current_step: usize,
}
389
390impl DelayedGradient {
391 pub fn new(config: DelayedGradientConfig, initial_parameters: Vec<Tensor>) -> Self {
393 Self {
394 config,
395 parameters: initial_parameters,
396 gradient_buffer: Vec::new(),
397 current_step: 0,
398 }
399 }
400
401 pub fn add_delayed_gradient(&mut self, gradient: Tensor, delay: usize) {
403 self.gradient_buffer.push((gradient, delay, Instant::now()));
404 }
405
406 pub fn step(&mut self) -> Result<()> {
408 self.current_step += 1;
409
410 let mut i = 0;
412 while i < self.gradient_buffer.len() {
413 let (ref gradient, delay, timestamp) = &self.gradient_buffer[i];
414 let age = timestamp.elapsed();
415
416 if age >= Duration::from_millis((*delay as u64) * 10) {
417 let compensation = self.compute_delay_compensation(*delay)?;
419 let compensated_lr = self.config.learning_rate * compensation;
420
421 for (j, param) in self.parameters.iter_mut().enumerate() {
423 if j < 1 {
424 let update = gradient.mul_scalar(compensated_lr)?;
426 *param = param.sub(&update)?;
427 }
428 }
429
430 self.gradient_buffer.remove(i);
431 } else {
432 i += 1;
433 }
434 }
435
436 Ok(())
437 }
438
439 fn compute_delay_compensation(&self, delay: usize) -> Result<f32> {
440 if delay > self.config.max_delay {
441 return Ok(0.0); }
443
444 let delay_ratio = delay as f32 / self.config.max_delay as f32;
445
446 let compensation = match self.config.compensation_method {
447 DelayCompensationMethod::None => 1.0,
448 DelayCompensationMethod::LinearDecay => {
449 1.0 - delay_ratio * self.config.compensation_factor
450 },
451 DelayCompensationMethod::ExponentialDecay => {
452 (-delay_ratio * self.config.compensation_factor).exp()
453 },
454 DelayCompensationMethod::Adaptive => {
455 1.0 / (1.0 + delay_ratio * self.config.compensation_factor)
457 },
458 };
459
460 Ok(compensation.max(0.1)) }
462
463 pub fn get_parameters(&self) -> &[Tensor] {
465 &self.parameters
466 }
467}
468
/// EASGD-style worker: local SGD plus an elastic pull toward shared global
/// parameters, with periodic two-way averaging.
pub struct ElasticAveraging {
    /// Hyperparameters for this worker.
    config: ElasticAveragingConfig,
    #[allow(dead_code)]
    worker_id: usize,
    /// This worker's local copy of the parameters.
    local_parameters: Vec<Tensor>,
    /// Shared center variable all workers are elastically tied to.
    global_parameters: Arc<RwLock<Vec<Tensor>>>,
    /// Last computed elastic force per parameter: alpha * (local - global).
    elastic_force: Vec<Tensor>,
    /// Local steps taken.
    local_step: usize,
    /// Local step at which the last global communication happened.
    last_communication: usize,
}
480
481impl ElasticAveraging {
482 pub fn new(
484 config: ElasticAveragingConfig,
485 worker_id: usize,
486 global_parameters: Arc<RwLock<Vec<Tensor>>>,
487 ) -> Result<Self> {
488 let global_params = global_parameters.read().clone();
489 let param_count = global_params.len();
490
491 Ok(Self {
492 config,
493 worker_id,
494 local_parameters: global_params.clone(),
495 global_parameters,
496 elastic_force: (0..param_count)
497 .map(|i| Tensor::zeros(&global_params[i].shape()).map_err(anyhow::Error::from))
498 .collect::<Result<Vec<_>>>()?,
499 local_step: 0,
500 last_communication: 0,
501 })
502 }
503
504 pub fn step(&mut self, gradients: &[Tensor]) -> Result<()> {
506 for (i, gradient) in gradients.iter().enumerate() {
508 if i < self.local_parameters.len() {
509 let update = gradient.mul_scalar(self.config.learning_rate)?;
510 self.local_parameters[i] = self.local_parameters[i].sub(&update)?;
511 }
512 }
513
514 let global_params = self.global_parameters.read();
516 for i in 0..self.local_parameters.len() {
517 let diff = self.local_parameters[i].sub(&global_params[i])?;
518 self.elastic_force[i] = diff.mul_scalar(self.config.alpha)?;
519 let elastic_update = self.elastic_force[i].mul_scalar(self.config.learning_rate)?;
520 self.local_parameters[i] = self.local_parameters[i].sub(&elastic_update)?;
521 }
522 drop(global_params);
523
524 self.local_step += 1;
525
526 if self.local_step - self.last_communication >= self.config.tau {
528 self.communicate_with_global()?;
529 self.last_communication = self.local_step;
530 }
531
532 Ok(())
533 }
534
535 fn communicate_with_global(&mut self) -> Result<()> {
536 let mut global_params = self.global_parameters.write();
537
538 for i in 0..global_params.len() {
540 let local_contrib = self.local_parameters[i].mul_scalar(1.0 - self.config.beta)?;
541 let global_contrib = global_params[i].mul_scalar(self.config.beta)?;
542 global_params[i] = local_contrib.add(&global_contrib)?;
543 }
544
545 self.local_parameters = global_params.clone();
547
548 Ok(())
549 }
550
551 pub fn get_parameters(&self) -> &[Tensor] {
553 &self.local_parameters
554 }
555}
556
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_async_sgd_config() {
        let AsyncSGDConfig {
            learning_rate,
            momentum,
            max_staleness,
            ..
        } = AsyncSGDConfig::default();
        assert_eq!(learning_rate, 1e-3);
        assert_eq!(momentum, 0.9);
        assert_eq!(max_staleness, 10);
    }

    #[test]
    fn test_hogwild_config() {
        let HogwildConfig {
            learning_rate,
            sparse_ratio,
            max_workers,
        } = HogwildConfig::default();
        assert_eq!(learning_rate, 1e-3);
        assert_eq!(sparse_ratio, 0.1);
        assert_eq!(max_workers, 4);
    }

    #[test]
    fn test_delayed_gradient_config() {
        let cfg = DelayedGradientConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.max_delay, 20);
        assert!(matches!(
            cfg.compensation_method,
            DelayCompensationMethod::LinearDecay
        ));
    }

    #[test]
    fn test_parameter_server_creation() {
        let initial = vec![Tensor::zeros(&[10]).unwrap()];
        let server = ParameterServer::new(initial);
        assert_eq!(server.get_global_step(), 0);
    }

    #[test]
    fn test_elastic_averaging_config() {
        let cfg = ElasticAveragingConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.alpha, 0.6);
        assert_eq!(cfg.tau, 10);
        assert_eq!(cfg.beta, 0.9);
    }
}