1#![allow(dead_code)]
9use crate::rpc::{register_function, rpc_async};
10use crate::{TorshDistributedError, TorshResult};
11use dashmap::DashMap;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::{Arc, RwLock};
15use tokio::sync::Mutex;
16use torsh_nn::Parameter;
17use torsh_tensor::Tensor;
18use tracing::{debug, info};
19
/// Configuration for a parameter server instance.
#[derive(Debug, Clone)]
pub struct ParameterServerConfig {
    /// Learning rate applied when folding pushed gradients into parameters.
    pub learning_rate: f32,
    /// Whether to maintain per-parameter momentum buffers.
    pub use_momentum: bool,
    /// Momentum coefficient (only used when `use_momentum` is true).
    pub momentum: f32,
    /// L2 weight-decay coefficient; 0.0 disables decay.
    pub weight_decay: f32,
    /// Maximum number of concurrent update operations.
    /// NOTE(review): not enforced anywhere visible in this file — confirm usage.
    pub max_concurrent_updates: usize,
    /// Optional gradient-norm clipping threshold; `None` disables clipping.
    pub gradient_clip_value: Option<f32>,
}
36
37impl Default for ParameterServerConfig {
38 fn default() -> Self {
39 Self {
40 learning_rate: 0.01,
41 use_momentum: true,
42 momentum: 0.9,
43 weight_decay: 0.0,
44 max_concurrent_updates: 10,
45 gradient_clip_value: None,
46 }
47 }
48}
49
/// Requests a worker can send to the parameter server over RPC.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParameterServerMessage {
    /// Submit flattened gradients for one or more parameters.
    PushGradients {
        /// Id of the sending worker.
        worker_id: u32,
        /// Parameter name -> flattened gradient values.
        gradients: HashMap<String, Vec<f32>>,
        /// Parameter version the gradients were computed against.
        version: u64,
    },
    /// Fetch the current values of the named parameters.
    PullParameters {
        worker_id: u32,
        param_names: Vec<String>,
    },
    /// Seed the server with initial values (name -> flattened data).
    InitializeParameters {
        parameters: HashMap<String, Vec<f32>>,
    },
    /// Request a snapshot of server-side statistics.
    GetStats,
}
71
/// Replies sent back by the parameter server, one per request variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParameterServerResponse {
    /// Reply to `PushGradients`; carries the version after the update.
    PushResponse { success: bool, new_version: u64 },
    /// Reply to `PullParameters`; parameter name -> flattened values.
    PullResponse {
        parameters: HashMap<String, Vec<f32>>,
        version: u64,
    },
    /// Reply to `InitializeParameters`.
    InitResponse { success: bool },
    /// Reply to `GetStats`.
    StatsResponse { stats: ParameterServerStats },
}
87
/// Point-in-time operational statistics reported by the parameter server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterServerStats {
    /// Number of parameters currently registered.
    pub num_parameters: usize,
    /// Total gradient pushes processed since startup.
    pub total_pushes: u64,
    /// Total parameter pulls served since startup.
    pub total_pulls: u64,
    /// Monotonically increasing parameter version (bumped on each push).
    pub current_version: u64,
    /// Number of distinct workers that have pushed gradients.
    pub active_workers: usize,
    /// Estimated parameter memory footprint in MB.
    /// NOTE(review): the producing code uses a fixed per-parameter size
    /// assumption, not actual tensor sizes — treat as a rough estimate.
    pub memory_usage_mb: f64,
}
104
/// Shared, internally synchronized state backing a parameter server.
struct ParameterServerState {
    /// Current parameter tensors, keyed by name.
    parameters: DashMap<String, Arc<RwLock<Tensor>>>,
    /// Momentum buffers (same keys/shapes as `parameters`; only populated
    /// when `config.use_momentum` is set).
    momentum_buffers: DashMap<String, Arc<RwLock<Tensor>>>,
    /// Global parameter version, incremented once per processed push.
    version: Arc<RwLock<u64>>,
    /// Immutable server configuration.
    config: ParameterServerConfig,
    /// Aggregate counters; a tokio Mutex because it is locked in async methods.
    stats: Arc<Mutex<ParameterServerStats>>,
    /// Ids of workers that have pushed gradients at least once.
    active_workers: Arc<RwLock<std::collections::HashSet<u32>>>,
    /// Recent (worker_id, param_name, grad_norm) samples, capped near 1000 entries.
    gradient_history: Arc<RwLock<Vec<(u32, String, f32)>>>,
}
122
123impl ParameterServerState {
124 fn new(config: ParameterServerConfig) -> Self {
125 Self {
126 parameters: DashMap::new(),
127 momentum_buffers: DashMap::new(),
128 version: Arc::new(RwLock::new(0)),
129 config,
130 stats: Arc::new(Mutex::new(ParameterServerStats {
131 num_parameters: 0,
132 total_pushes: 0,
133 total_pulls: 0,
134 current_version: 0,
135 active_workers: 0,
136 memory_usage_mb: 0.0,
137 })),
138 active_workers: Arc::new(RwLock::new(std::collections::HashSet::new())),
139 gradient_history: Arc::new(RwLock::new(Vec::new())),
140 }
141 }
142
143 async fn initialize_parameters(
145 &self,
146 parameters: HashMap<String, Vec<f32>>,
147 ) -> TorshResult<bool> {
148 info!(
149 "Initializing {} parameters on parameter server",
150 parameters.len()
151 );
152
153 for (name, data) in parameters {
154 let shape = vec![data.len()]; let tensor = Tensor::from_vec(data, &shape)?;
156 self.parameters
157 .insert(name.clone(), Arc::new(RwLock::new(tensor)));
158
159 if self.config.use_momentum {
161 let zeros = Tensor::zeros(&shape, torsh_core::DeviceType::Cpu)?;
162 self.momentum_buffers
163 .insert(name, Arc::new(RwLock::new(zeros)));
164 }
165 }
166
167 {
169 let mut stats = self.stats.lock().await;
170 stats.num_parameters = self.parameters.len();
171 stats.current_version = *self.version.read().expect("lock should not be poisoned");
172 }
173
174 Ok(true)
175 }
176
177 async fn push_gradients(
179 &self,
180 worker_id: u32,
181 gradients: HashMap<String, Vec<f32>>,
182 _version: u64,
183 ) -> TorshResult<u64> {
184 debug!(
185 "Received gradients from worker {} for {} parameters",
186 worker_id,
187 gradients.len()
188 );
189
190 {
192 let mut workers = self
193 .active_workers
194 .write()
195 .expect("lock should not be poisoned");
196 workers.insert(worker_id);
197 }
198
199 let mut gradient_norms = Vec::new();
200
201 for (param_name, grad_data) in gradients {
202 if let Some(param_entry) = self.parameters.get(¶m_name) {
203 let param_tensor = param_entry.clone();
204 let mut param_guard = param_tensor.write().expect("lock should not be poisoned");
205
206 let shape = param_guard.shape().dims().to_vec();
208 let grad_tensor = Tensor::from_vec(grad_data, &shape)?;
209
210 let grad_norm = grad_tensor.norm()?.item()?;
212 gradient_norms.push((worker_id, param_name.clone(), grad_norm));
213
214 let clipped_grad = if let Some(clip_value) = self.config.gradient_clip_value {
216 if grad_norm > clip_value {
217 grad_tensor.mul_scalar(clip_value / grad_norm)?
218 } else {
219 grad_tensor
220 }
221 } else {
222 grad_tensor
223 };
224
225 let grad_with_decay = if self.config.weight_decay > 0.0 {
227 let weight_penalty = param_guard.mul_scalar(self.config.weight_decay)?;
228 clipped_grad.add(&weight_penalty)?
229 } else {
230 clipped_grad
231 };
232
233 let update = if self.config.use_momentum {
235 if let Some(momentum_entry) = self.momentum_buffers.get(¶m_name) {
236 let momentum_tensor = momentum_entry.clone();
237 let mut momentum_guard = momentum_tensor
238 .write()
239 .expect("lock should not be poisoned");
240
241 *momentum_guard = momentum_guard
243 .mul_scalar(self.config.momentum)?
244 .add(&grad_with_decay)?;
245 momentum_guard.clone()
246 } else {
247 grad_with_decay
248 }
249 } else {
250 grad_with_decay
251 };
252
253 *param_guard = param_guard.sub(&update.mul_scalar(self.config.learning_rate)?)?;
255 }
256 }
257
258 {
260 let mut history = self
261 .gradient_history
262 .write()
263 .expect("lock should not be poisoned");
264 history.extend(gradient_norms);
265 if history.len() > 1000 {
267 history.drain(0..500);
268 }
269 }
270
271 let new_version = {
273 let mut version = self.version.write().expect("lock should not be poisoned");
274 *version += 1;
275 *version
276 };
277
278 {
280 let mut stats = self.stats.lock().await;
281 stats.total_pushes += 1;
282 stats.current_version = new_version;
283 stats.active_workers = self
284 .active_workers
285 .read()
286 .expect("lock should not be poisoned")
287 .len();
288 stats.memory_usage_mb = (self.parameters.len() * std::mem::size_of::<f32>() * 1000)
290 as f64
291 / (1024.0 * 1024.0);
292 }
293
294 Ok(new_version)
295 }
296
297 async fn pull_parameters(
299 &self,
300 worker_id: u32,
301 param_names: Vec<String>,
302 ) -> TorshResult<(HashMap<String, Vec<f32>>, u64)> {
303 debug!(
304 "Worker {} pulling {} parameters",
305 worker_id,
306 param_names.len()
307 );
308
309 let mut parameters = HashMap::new();
310
311 for param_name in param_names {
312 if let Some(param_entry) = self.parameters.get(¶m_name) {
313 let param_tensor = param_entry.clone();
314 let param_guard = param_tensor.read().expect("lock should not be poisoned");
315
316 let data = param_guard.flatten()?.to_vec()?;
318 parameters.insert(param_name, data);
319 }
320 }
321
322 let version = *self.version.read().expect("lock should not be poisoned");
323
324 {
326 let mut stats = self.stats.lock().await;
327 stats.total_pulls += 1;
328 }
329
330 Ok((parameters, version))
331 }
332
333 async fn get_stats(&self) -> ParameterServerStats {
335 self.stats.lock().await.clone()
336 }
337}
338
/// Server-side handle: owns the shared state and registers RPC endpoints.
pub struct ParameterServer {
    /// Shared state, cloned into the RPC handler closures.
    state: Arc<ParameterServerState>,
    /// Rank this server instance runs on (used for logging/routing).
    server_rank: u32,
}
344
impl ParameterServer {
    /// Create a parameter server with the given rank and configuration.
    pub fn new(server_rank: u32, config: ParameterServerConfig) -> Self {
        Self {
            state: Arc::new(ParameterServerState::new(config)),
            server_rank,
        }
    }

    /// Register the server's RPC endpoints (`ps_initialize`,
    /// `ps_push_gradients`, `ps_pull_parameters`, `ps_get_stats`).
    ///
    /// NOTE(review): the registered closures do NOT touch the server state —
    /// the cloned state is bound as `_state` and every handler returns a
    /// hard-coded placeholder response (always-success, `version + 1`, an
    /// empty parameter map, zeroed stats). Presumably the synchronous
    /// `register_function` closures cannot await the async state methods;
    /// confirm and wire in the real state before relying on these endpoints.
    pub async fn start(&self) -> TorshResult<()> {
        info!("Starting parameter server on rank {}", self.server_rank);

        let _state = self.state.clone();

        // Handler for parameter initialization (placeholder: always succeeds).
        register_function("ps_initialize", move |msg: ParameterServerMessage| {
            match msg {
                ParameterServerMessage::InitializeParameters {
                    parameters: _parameters,
                } => {
                    Ok(ParameterServerResponse::InitResponse { success: true })
                }
                _ => Err("Invalid message type for ps_initialize".to_string()),
            }
        })
        .await?;

        // Handler for gradient pushes (placeholder: echoes version + 1
        // without applying any update).
        register_function("ps_push_gradients", move |msg: ParameterServerMessage| {
            match msg {
                ParameterServerMessage::PushGradients {
                    worker_id: _,
                    gradients: _,
                    version,
                } => {
                    Ok(ParameterServerResponse::PushResponse {
                        success: true,
                        new_version: version + 1,
                    })
                }
                _ => Err("Invalid message type for ps_push_gradients".to_string()),
            }
        })
        .await?;

        // Handler for parameter pulls (placeholder: returns an empty map).
        register_function("ps_pull_parameters", move |msg: ParameterServerMessage| {
            match msg {
                ParameterServerMessage::PullParameters {
                    worker_id: _,
                    param_names: _,
                } => {
                    Ok(ParameterServerResponse::PullResponse {
                        parameters: std::collections::HashMap::new(),
                        version: 1,
                    })
                }
                _ => Err("Invalid message type for ps_pull_parameters".to_string()),
            }
        })
        .await?;

        // Handler for stats queries (placeholder: zeroed stats, version 1).
        register_function("ps_get_stats", move |msg: ParameterServerMessage| {
            match msg {
                ParameterServerMessage::GetStats => {
                    let stats = ParameterServerStats {
                        num_parameters: 0,
                        total_pushes: 0,
                        total_pulls: 0,
                        current_version: 1,
                        active_workers: 0,
                        memory_usage_mb: 0.0,
                    };
                    Ok(ParameterServerResponse::StatsResponse { stats })
                }
                _ => Err("Invalid message type for ps_get_stats".to_string()),
            }
        })
        .await?;

        info!(
            "Parameter server started successfully on rank {}",
            self.server_rank
        );
        Ok(())
    }

    /// Snapshot of the server-side statistics (from local state, not RPC).
    pub async fn get_statistics(&self) -> ParameterServerStats {
        self.state.get_stats().await
    }

    /// Current global parameter version.
    pub fn get_version(&self) -> u64 {
        *self
            .state
            .version
            .read()
            .expect("lock should not be poisoned")
    }

    /// Number of parameters currently registered.
    pub fn num_parameters(&self) -> usize {
        self.state.parameters.len()
    }

    /// Whether a parameter with the given name is registered.
    pub fn has_parameter(&self, name: &str) -> bool {
        self.state.parameters.contains_key(name)
    }
}
459
/// Worker-side client that talks to a remote parameter server via RPC.
pub struct ParameterServerClient {
    /// Rank of the parameter server to address.
    server_rank: u32,
    /// Id of this worker, echoed in every request.
    worker_id: u32,
    /// Last parameter version seen from the server.
    current_version: Arc<RwLock<u64>>,
}
466
467impl ParameterServerClient {
468 pub fn new(server_rank: u32, worker_id: u32) -> Self {
470 Self {
471 server_rank,
472 worker_id,
473 current_version: Arc::new(RwLock::new(0)),
474 }
475 }
476
477 pub async fn initialize_parameters(
479 &self,
480 parameters: HashMap<String, Parameter>,
481 ) -> TorshResult<()> {
482 let mut param_data = HashMap::new();
483
484 for (name, param) in parameters {
485 let tensor = param.tensor();
486 let tensor_guard = tensor.read();
487 let data = tensor_guard.flatten()?.to_vec()?;
488 param_data.insert(name, data);
489 }
490
491 let message = ParameterServerMessage::InitializeParameters {
492 parameters: param_data,
493 };
494
495 let response: ParameterServerResponse =
496 rpc_async(self.server_rank, "ps_initialize", message).await?;
497
498 match response {
499 ParameterServerResponse::InitResponse { success } => {
500 if success {
501 info!("Successfully initialized parameters on parameter server");
502 Ok(())
503 } else {
504 Err(TorshDistributedError::backend_error(
505 "parameter_server",
506 "Failed to initialize parameters",
507 ))
508 }
509 }
510 _ => Err(TorshDistributedError::backend_error(
511 "parameter_server",
512 "Unexpected response type",
513 )),
514 }
515 }
516
517 pub async fn push_gradients(&self, gradients: HashMap<String, Tensor>) -> TorshResult<u64> {
519 let mut grad_data = HashMap::new();
520
521 for (name, grad) in gradients {
522 let data = grad.flatten()?.to_vec()?;
523 grad_data.insert(name, data);
524 }
525
526 let current_version = *self
527 .current_version
528 .read()
529 .expect("lock should not be poisoned");
530 let message = ParameterServerMessage::PushGradients {
531 worker_id: self.worker_id,
532 gradients: grad_data,
533 version: current_version,
534 };
535
536 let response: ParameterServerResponse =
537 rpc_async(self.server_rank, "ps_push_gradients", message).await?;
538
539 match response {
540 ParameterServerResponse::PushResponse {
541 success,
542 new_version,
543 } => {
544 if success {
545 *self
546 .current_version
547 .write()
548 .expect("lock should not be poisoned") = new_version;
549 debug!(
550 "Successfully pushed gradients, new version: {}",
551 new_version
552 );
553 Ok(new_version)
554 } else {
555 Err(TorshDistributedError::backend_error(
556 "parameter_server",
557 "Failed to push gradients",
558 ))
559 }
560 }
561 _ => Err(TorshDistributedError::backend_error(
562 "parameter_server",
563 "Unexpected response type",
564 )),
565 }
566 }
567
568 pub async fn pull_parameters(
570 &self,
571 param_names: Vec<String>,
572 ) -> TorshResult<HashMap<String, Tensor>> {
573 let message = ParameterServerMessage::PullParameters {
574 worker_id: self.worker_id,
575 param_names: param_names.clone(),
576 };
577
578 let response: ParameterServerResponse =
579 rpc_async(self.server_rank, "ps_pull_parameters", message).await?;
580
581 match response {
582 ParameterServerResponse::PullResponse {
583 parameters,
584 version,
585 } => {
586 let mut result = HashMap::new();
587
588 for (name, data) in parameters {
589 let shape = vec![data.len()]; let tensor = Tensor::from_vec(data, &shape)?;
591 result.insert(name, tensor);
592 }
593
594 *self
595 .current_version
596 .write()
597 .expect("lock should not be poisoned") = version;
598 debug!(
599 "Successfully pulled {} parameters, version: {}",
600 result.len(),
601 version
602 );
603 Ok(result)
604 }
605 _ => Err(TorshDistributedError::backend_error(
606 "parameter_server",
607 "Unexpected response type",
608 )),
609 }
610 }
611
612 pub async fn get_server_stats(&self) -> TorshResult<ParameterServerStats> {
614 let message = ParameterServerMessage::GetStats;
615 let response: ParameterServerResponse =
616 rpc_async(self.server_rank, "ps_get_stats", message).await?;
617
618 match response {
619 ParameterServerResponse::StatsResponse { stats } => Ok(stats),
620 _ => Err(TorshDistributedError::backend_error(
621 "parameter_server",
622 "Unexpected response type",
623 )),
624 }
625 }
626
627 pub fn get_local_version(&self) -> u64 {
629 *self
630 .current_version
631 .read()
632 .expect("lock should not be poisoned")
633 }
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parameter_server_creation() {
        // A fresh server starts empty at version 0.
        let ps = ParameterServer::new(0, ParameterServerConfig::default());

        assert_eq!(ps.server_rank, 0);
        assert_eq!(ps.num_parameters(), 0);
        assert_eq!(ps.get_version(), 0);
    }

    #[tokio::test]
    async fn test_parameter_server_config() {
        // Struct-update syntax keeps the unspecified fields at their defaults.
        let cfg = ParameterServerConfig {
            learning_rate: 0.001,
            use_momentum: false,
            gradient_clip_value: Some(1.0),
            ..Default::default()
        };

        assert_eq!(cfg.learning_rate, 0.001);
        assert_eq!(cfg.gradient_clip_value, Some(1.0));
        assert!(!cfg.use_momentum);
    }

    #[tokio::test]
    async fn test_parameter_server_client() {
        let client = ParameterServerClient::new(0, 1);

        assert_eq!(client.server_rank, 0);
        assert_eq!(client.worker_id, 1);
        assert_eq!(client.get_local_version(), 0);
    }

    #[tokio::test]
    async fn test_parameter_server_stats() {
        // Plain value round-trip through the stats struct.
        let snapshot = ParameterServerStats {
            num_parameters: 100,
            total_pushes: 50,
            total_pulls: 30,
            current_version: 10,
            active_workers: 3,
            memory_usage_mb: 128.5,
        };

        assert_eq!(snapshot.num_parameters, 100);
        assert_eq!(snapshot.total_pushes, 50);
        assert_eq!(snapshot.active_workers, 3);
        assert_eq!(snapshot.memory_usage_mb, 128.5);
    }

    #[tokio::test]
    #[ignore]
    async fn test_parameter_server_integration() -> TorshResult<()> {
        // Requires a live RPC backend; only constructs the server for now.
        let _server = ParameterServer::new(0, ParameterServerConfig::default());

        Ok(())
    }
}