// torsh_distributed/parameter_server.rs

//! Parameter Server implementation for distributed training
//!
//! The parameter server provides a centralized approach to distributed training
//! where workers send gradients to parameter servers which update the global model
//! and send back updated parameters.

// Framework infrastructure - components designed for future use
8#![allow(dead_code)]
9use crate::rpc::{register_function, rpc_async};
10use crate::{TorshDistributedError, TorshResult};
11use dashmap::DashMap;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::{Arc, RwLock};
15use tokio::sync::Mutex;
16use torsh_nn::Parameter;
17use torsh_tensor::Tensor;
18use tracing::{debug, info};
19
/// Parameter server configuration
///
/// Controls the SGD-style update rule the server applies when workers push
/// gradients. Defaults are provided by the `Default` impl below.
#[derive(Debug, Clone)]
pub struct ParameterServerConfig {
    /// Learning rate for parameter updates (default 0.01)
    pub learning_rate: f32,
    /// Whether to use momentum accumulation (default true)
    pub use_momentum: bool,
    /// Momentum coefficient, only used when `use_momentum` is set (default 0.9)
    pub momentum: f32,
    /// Weight decay (L2 penalty) factor; 0.0 disables it (default 0.0)
    pub weight_decay: f32,
    /// Maximum number of concurrent updates (default 10)
    // NOTE(review): not enforced anywhere in this file yet — confirm intended use.
    pub max_concurrent_updates: usize,
    /// Gradient clipping threshold applied to the gradient norm;
    /// `None` disables clipping (default)
    pub gradient_clip_value: Option<f32>,
}

37impl Default for ParameterServerConfig {
38    fn default() -> Self {
39        Self {
40            learning_rate: 0.01,
41            use_momentum: true,
42            momentum: 0.9,
43            weight_decay: 0.0,
44            max_concurrent_updates: 10,
45            gradient_clip_value: None,
46        }
47    }
48}
49
/// Parameter server message types
///
/// Requests sent from workers to the server over RPC. Tensor data is always
/// transmitted as flattened `Vec<f32>`, keyed by parameter name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParameterServerMessage {
    /// Push gradients to the server
    PushGradients {
        /// Id of the pushing worker
        worker_id: u32,
        /// Flattened gradient data per parameter name
        gradients: HashMap<String, Vec<f32>>,
        /// Parameter version the worker computed these gradients against
        version: u64,
    },
    /// Pull parameters from the server
    PullParameters {
        /// Id of the requesting worker
        worker_id: u32,
        /// Names of the parameters to fetch
        param_names: Vec<String>,
    },
    /// Initialize parameters on the server
    InitializeParameters {
        /// Flattened initial values per parameter name
        parameters: HashMap<String, Vec<f32>>,
    },
    /// Get parameter server statistics
    GetStats,
}

/// Parameter server response types
///
/// One variant per request variant in [`ParameterServerMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParameterServerResponse {
    /// Response to push gradients; `new_version` is the global parameter
    /// version after applying the push
    PushResponse { success: bool, new_version: u64 },
    /// Response to pull parameters; data is flattened per parameter name
    PullResponse {
        parameters: HashMap<String, Vec<f32>>,
        version: u64,
    },
    /// Response to initialization
    InitResponse { success: bool },
    /// Statistics response
    StatsResponse { stats: ParameterServerStats },
}

/// Parameter server statistics
///
/// A point-in-time snapshot of server activity; serializable so it can be
/// returned over RPC via [`ParameterServerResponse::StatsResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterServerStats {
    /// Number of parameters stored
    pub num_parameters: usize,
    /// Total number of gradient pushes received since startup
    pub total_pushes: u64,
    /// Total number of parameter pulls since startup
    pub total_pulls: u64,
    /// Current global parameter version (incremented once per push)
    pub current_version: u64,
    /// Number of distinct workers that have pushed gradients
    pub active_workers: usize,
    /// Estimated memory usage of stored parameters, in megabytes
    pub memory_usage_mb: f64,
}

/// Parameter server state
///
/// Shared, concurrently accessed store behind the server. `DashMap` gives
/// per-entry concurrent access to the maps; each tensor is additionally
/// guarded by its own `RwLock` so individual parameters can be updated
/// independently.
struct ParameterServerState {
    /// Stored parameters, keyed by name
    parameters: DashMap<String, Arc<RwLock<Tensor>>>,
    /// Momentum buffers, one per parameter; populated only when
    /// `config.use_momentum` is set
    momentum_buffers: DashMap<String, Arc<RwLock<Tensor>>>,
    /// Global parameter version, incremented once per gradient push
    version: Arc<RwLock<u64>>,
    /// Configuration
    config: ParameterServerConfig,
    /// Statistics; tokio (async) mutex because it is locked from async methods
    stats: Arc<Mutex<ParameterServerStats>>,
    /// Ids of workers that have pushed gradients at least once
    active_workers: Arc<RwLock<std::collections::HashSet<u32>>>,
    /// Gradient history for debugging, trimmed to bound memory
    gradient_history: Arc<RwLock<Vec<(u32, String, f32)>>>, // (worker_id, param_name, gradient_norm)
}

123impl ParameterServerState {
124    fn new(config: ParameterServerConfig) -> Self {
125        Self {
126            parameters: DashMap::new(),
127            momentum_buffers: DashMap::new(),
128            version: Arc::new(RwLock::new(0)),
129            config,
130            stats: Arc::new(Mutex::new(ParameterServerStats {
131                num_parameters: 0,
132                total_pushes: 0,
133                total_pulls: 0,
134                current_version: 0,
135                active_workers: 0,
136                memory_usage_mb: 0.0,
137            })),
138            active_workers: Arc::new(RwLock::new(std::collections::HashSet::new())),
139            gradient_history: Arc::new(RwLock::new(Vec::new())),
140        }
141    }
142
143    /// Initialize parameters on the server
144    async fn initialize_parameters(
145        &self,
146        parameters: HashMap<String, Vec<f32>>,
147    ) -> TorshResult<bool> {
148        info!(
149            "Initializing {} parameters on parameter server",
150            parameters.len()
151        );
152
153        for (name, data) in parameters {
154            let shape = vec![data.len()]; // Simple 1D shape for now
155            let tensor = Tensor::from_vec(data, &shape)?;
156            self.parameters
157                .insert(name.clone(), Arc::new(RwLock::new(tensor)));
158
159            // Initialize momentum buffer if needed
160            if self.config.use_momentum {
161                let zeros = Tensor::zeros(&shape, torsh_core::DeviceType::Cpu)?;
162                self.momentum_buffers
163                    .insert(name, Arc::new(RwLock::new(zeros)));
164            }
165        }
166
167        // Update stats
168        {
169            let mut stats = self.stats.lock().await;
170            stats.num_parameters = self.parameters.len();
171            stats.current_version = *self.version.read().expect("lock should not be poisoned");
172        }
173
174        Ok(true)
175    }
176
177    /// Handle gradient push from a worker
178    async fn push_gradients(
179        &self,
180        worker_id: u32,
181        gradients: HashMap<String, Vec<f32>>,
182        _version: u64,
183    ) -> TorshResult<u64> {
184        debug!(
185            "Received gradients from worker {} for {} parameters",
186            worker_id,
187            gradients.len()
188        );
189
190        // Track active worker
191        {
192            let mut workers = self
193                .active_workers
194                .write()
195                .expect("lock should not be poisoned");
196            workers.insert(worker_id);
197        }
198
199        let mut gradient_norms = Vec::new();
200
201        for (param_name, grad_data) in gradients {
202            if let Some(param_entry) = self.parameters.get(&param_name) {
203                let param_tensor = param_entry.clone();
204                let mut param_guard = param_tensor.write().expect("lock should not be poisoned");
205
206                // Convert gradient data to tensor
207                let shape = param_guard.shape().dims().to_vec();
208                let grad_tensor = Tensor::from_vec(grad_data, &shape)?;
209
210                // Calculate gradient norm for statistics
211                let grad_norm = grad_tensor.norm()?.item()?;
212                gradient_norms.push((worker_id, param_name.clone(), grad_norm));
213
214                // Apply gradient clipping if configured
215                let clipped_grad = if let Some(clip_value) = self.config.gradient_clip_value {
216                    if grad_norm > clip_value {
217                        grad_tensor.mul_scalar(clip_value / grad_norm)?
218                    } else {
219                        grad_tensor
220                    }
221                } else {
222                    grad_tensor
223                };
224
225                // Apply weight decay if configured
226                let grad_with_decay = if self.config.weight_decay > 0.0 {
227                    let weight_penalty = param_guard.mul_scalar(self.config.weight_decay)?;
228                    clipped_grad.add(&weight_penalty)?
229                } else {
230                    clipped_grad
231                };
232
233                // Apply momentum if configured
234                let update = if self.config.use_momentum {
235                    if let Some(momentum_entry) = self.momentum_buffers.get(&param_name) {
236                        let momentum_tensor = momentum_entry.clone();
237                        let mut momentum_guard = momentum_tensor
238                            .write()
239                            .expect("lock should not be poisoned");
240
241                        // momentum = momentum * momentum_factor + gradient
242                        *momentum_guard = momentum_guard
243                            .mul_scalar(self.config.momentum)?
244                            .add(&grad_with_decay)?;
245                        momentum_guard.clone()
246                    } else {
247                        grad_with_decay
248                    }
249                } else {
250                    grad_with_decay
251                };
252
253                // Apply parameter update: param = param - learning_rate * update
254                *param_guard = param_guard.sub(&update.mul_scalar(self.config.learning_rate)?)?;
255            }
256        }
257
258        // Update gradient history
259        {
260            let mut history = self
261                .gradient_history
262                .write()
263                .expect("lock should not be poisoned");
264            history.extend(gradient_norms);
265            // Keep only recent entries to prevent unbounded growth
266            if history.len() > 1000 {
267                history.drain(0..500);
268            }
269        }
270
271        // Increment version
272        let new_version = {
273            let mut version = self.version.write().expect("lock should not be poisoned");
274            *version += 1;
275            *version
276        };
277
278        // Update statistics
279        {
280            let mut stats = self.stats.lock().await;
281            stats.total_pushes += 1;
282            stats.current_version = new_version;
283            stats.active_workers = self
284                .active_workers
285                .read()
286                .expect("lock should not be poisoned")
287                .len();
288            // Estimate memory usage (simplified)
289            stats.memory_usage_mb = (self.parameters.len() * std::mem::size_of::<f32>() * 1000)
290                as f64
291                / (1024.0 * 1024.0);
292        }
293
294        Ok(new_version)
295    }
296
297    /// Handle parameter pull request from a worker
298    async fn pull_parameters(
299        &self,
300        worker_id: u32,
301        param_names: Vec<String>,
302    ) -> TorshResult<(HashMap<String, Vec<f32>>, u64)> {
303        debug!(
304            "Worker {} pulling {} parameters",
305            worker_id,
306            param_names.len()
307        );
308
309        let mut parameters = HashMap::new();
310
311        for param_name in param_names {
312            if let Some(param_entry) = self.parameters.get(&param_name) {
313                let param_tensor = param_entry.clone();
314                let param_guard = param_tensor.read().expect("lock should not be poisoned");
315
316                // Convert tensor to Vec<f32>
317                let data = param_guard.flatten()?.to_vec()?;
318                parameters.insert(param_name, data);
319            }
320        }
321
322        let version = *self.version.read().expect("lock should not be poisoned");
323
324        // Update statistics
325        {
326            let mut stats = self.stats.lock().await;
327            stats.total_pulls += 1;
328        }
329
330        Ok((parameters, version))
331    }
332
333    /// Get server statistics
334    async fn get_stats(&self) -> ParameterServerStats {
335        self.stats.lock().await.clone()
336    }
337}
338
/// Parameter server instance
///
/// Owns the shared server state and knows its own rank; call
/// [`ParameterServer::start`] to register the RPC endpoints.
pub struct ParameterServer {
    /// Shared mutable state (parameters, momentum buffers, statistics)
    state: Arc<ParameterServerState>,
    /// Rank this server runs on within the process group
    server_rank: u32,
}

impl ParameterServer {
    /// Create a new parameter server
    ///
    /// `server_rank` identifies this server in the process group; `config`
    /// controls the update rule applied to pushed gradients.
    pub fn new(server_rank: u32, config: ParameterServerConfig) -> Self {
        Self {
            state: Arc::new(ParameterServerState::new(config)),
            server_rank,
        }
    }

    /// Start the parameter server (register RPC functions)
    ///
    /// Registers handlers for the four RPC endpoints: `ps_initialize`,
    /// `ps_push_gradients`, `ps_pull_parameters`, and `ps_get_stats`.
    ///
    /// NOTE(review): the registered closures are placeholders — they never
    /// touch the server state and fabricate their responses (push always
    /// answers `version + 1`, pull returns an empty map, stats are
    /// hard-coded). Wiring them to `ParameterServerState` presumably needs
    /// async-capable handlers; confirm against the `register_function` API.
    pub async fn start(&self) -> TorshResult<()> {
        info!("Starting parameter server on rank {}", self.server_rank);

        // Cloned so future (non-stub) handlers can capture it; unused while
        // the handlers below remain stubs.
        let _state = self.state.clone();

        // Register parameter server functions
        register_function("ps_initialize", move |msg: ParameterServerMessage| {
            match msg {
                ParameterServerMessage::InitializeParameters {
                    parameters: _parameters,
                } => {
                    // For now, simplified synchronous version
                    Ok(ParameterServerResponse::InitResponse { success: true })
                }
                _ => Err("Invalid message type for ps_initialize".to_string()),
            }
        })
        .await?;

        register_function("ps_push_gradients", move |msg: ParameterServerMessage| {
            match msg {
                ParameterServerMessage::PushGradients {
                    worker_id: _,
                    gradients: _,
                    version,
                } => {
                    // Simplified synchronous version: gradients are dropped,
                    // only the version is echoed back incremented.
                    Ok(ParameterServerResponse::PushResponse {
                        success: true,
                        new_version: version + 1,
                    })
                }
                _ => Err("Invalid message type for ps_push_gradients".to_string()),
            }
        })
        .await?;

        register_function("ps_pull_parameters", move |msg: ParameterServerMessage| {
            match msg {
                ParameterServerMessage::PullParameters {
                    worker_id: _,
                    param_names: _,
                } => {
                    // Simplified synchronous version: returns no parameters.
                    Ok(ParameterServerResponse::PullResponse {
                        parameters: std::collections::HashMap::new(),
                        version: 1,
                    })
                }
                _ => Err("Invalid message type for ps_pull_parameters".to_string()),
            }
        })
        .await?;

        register_function("ps_get_stats", move |msg: ParameterServerMessage| {
            match msg {
                ParameterServerMessage::GetStats => {
                    // Simplified synchronous version: hard-coded stats.
                    let stats = ParameterServerStats {
                        num_parameters: 0,
                        total_pushes: 0,
                        total_pulls: 0,
                        current_version: 1,
                        active_workers: 0,
                        memory_usage_mb: 0.0,
                    };
                    Ok(ParameterServerResponse::StatsResponse { stats })
                }
                _ => Err("Invalid message type for ps_get_stats".to_string()),
            }
        })
        .await?;

        info!(
            "Parameter server started successfully on rank {}",
            self.server_rank
        );
        Ok(())
    }

    /// Get a snapshot of the server statistics.
    pub async fn get_statistics(&self) -> ParameterServerStats {
        self.state.get_stats().await
    }

    /// Get the current global parameter version.
    pub fn get_version(&self) -> u64 {
        *self
            .state
            .version
            .read()
            .expect("lock should not be poisoned")
    }

    /// Get the number of stored parameters.
    pub fn num_parameters(&self) -> usize {
        self.state.parameters.len()
    }

    /// Check whether a parameter with the given name exists.
    pub fn has_parameter(&self, name: &str) -> bool {
        self.state.parameters.contains_key(name)
    }
}

/// Parameter server client for workers
///
/// Thin RPC wrapper used by workers to initialize, push gradients to, and
/// pull parameters from a remote [`ParameterServer`]. Tracks the latest
/// parameter version the worker has observed.
pub struct ParameterServerClient {
    /// Rank hosting the parameter server
    server_rank: u32,
    /// This worker's id, included in every request
    worker_id: u32,
    /// Latest parameter version observed from the server
    current_version: Arc<RwLock<u64>>,
}

467impl ParameterServerClient {
468    /// Create a new parameter server client
469    pub fn new(server_rank: u32, worker_id: u32) -> Self {
470        Self {
471            server_rank,
472            worker_id,
473            current_version: Arc::new(RwLock::new(0)),
474        }
475    }
476
477    /// Initialize parameters on the server
478    pub async fn initialize_parameters(
479        &self,
480        parameters: HashMap<String, Parameter>,
481    ) -> TorshResult<()> {
482        let mut param_data = HashMap::new();
483
484        for (name, param) in parameters {
485            let tensor = param.tensor();
486            let tensor_guard = tensor.read();
487            let data = tensor_guard.flatten()?.to_vec()?;
488            param_data.insert(name, data);
489        }
490
491        let message = ParameterServerMessage::InitializeParameters {
492            parameters: param_data,
493        };
494
495        let response: ParameterServerResponse =
496            rpc_async(self.server_rank, "ps_initialize", message).await?;
497
498        match response {
499            ParameterServerResponse::InitResponse { success } => {
500                if success {
501                    info!("Successfully initialized parameters on parameter server");
502                    Ok(())
503                } else {
504                    Err(TorshDistributedError::backend_error(
505                        "parameter_server",
506                        "Failed to initialize parameters",
507                    ))
508                }
509            }
510            _ => Err(TorshDistributedError::backend_error(
511                "parameter_server",
512                "Unexpected response type",
513            )),
514        }
515    }
516
517    /// Push gradients to the parameter server
518    pub async fn push_gradients(&self, gradients: HashMap<String, Tensor>) -> TorshResult<u64> {
519        let mut grad_data = HashMap::new();
520
521        for (name, grad) in gradients {
522            let data = grad.flatten()?.to_vec()?;
523            grad_data.insert(name, data);
524        }
525
526        let current_version = *self
527            .current_version
528            .read()
529            .expect("lock should not be poisoned");
530        let message = ParameterServerMessage::PushGradients {
531            worker_id: self.worker_id,
532            gradients: grad_data,
533            version: current_version,
534        };
535
536        let response: ParameterServerResponse =
537            rpc_async(self.server_rank, "ps_push_gradients", message).await?;
538
539        match response {
540            ParameterServerResponse::PushResponse {
541                success,
542                new_version,
543            } => {
544                if success {
545                    *self
546                        .current_version
547                        .write()
548                        .expect("lock should not be poisoned") = new_version;
549                    debug!(
550                        "Successfully pushed gradients, new version: {}",
551                        new_version
552                    );
553                    Ok(new_version)
554                } else {
555                    Err(TorshDistributedError::backend_error(
556                        "parameter_server",
557                        "Failed to push gradients",
558                    ))
559                }
560            }
561            _ => Err(TorshDistributedError::backend_error(
562                "parameter_server",
563                "Unexpected response type",
564            )),
565        }
566    }
567
568    /// Pull parameters from the parameter server
569    pub async fn pull_parameters(
570        &self,
571        param_names: Vec<String>,
572    ) -> TorshResult<HashMap<String, Tensor>> {
573        let message = ParameterServerMessage::PullParameters {
574            worker_id: self.worker_id,
575            param_names: param_names.clone(),
576        };
577
578        let response: ParameterServerResponse =
579            rpc_async(self.server_rank, "ps_pull_parameters", message).await?;
580
581        match response {
582            ParameterServerResponse::PullResponse {
583                parameters,
584                version,
585            } => {
586                let mut result = HashMap::new();
587
588                for (name, data) in parameters {
589                    let shape = vec![data.len()]; // Simple 1D shape
590                    let tensor = Tensor::from_vec(data, &shape)?;
591                    result.insert(name, tensor);
592                }
593
594                *self
595                    .current_version
596                    .write()
597                    .expect("lock should not be poisoned") = version;
598                debug!(
599                    "Successfully pulled {} parameters, version: {}",
600                    result.len(),
601                    version
602                );
603                Ok(result)
604            }
605            _ => Err(TorshDistributedError::backend_error(
606                "parameter_server",
607                "Unexpected response type",
608            )),
609        }
610    }
611
612    /// Get server statistics
613    pub async fn get_server_stats(&self) -> TorshResult<ParameterServerStats> {
614        let message = ParameterServerMessage::GetStats;
615        let response: ParameterServerResponse =
616            rpc_async(self.server_rank, "ps_get_stats", message).await?;
617
618        match response {
619            ParameterServerResponse::StatsResponse { stats } => Ok(stats),
620            _ => Err(TorshDistributedError::backend_error(
621                "parameter_server",
622                "Unexpected response type",
623            )),
624        }
625    }
626
627    /// Get current local version
628    pub fn get_local_version(&self) -> u64 {
629        *self
630            .current_version
631            .read()
632            .expect("lock should not be poisoned")
633    }
634}
635
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed server stores nothing and starts at version 0.
    #[tokio::test]
    async fn test_parameter_server_creation() {
        let server = ParameterServer::new(0, ParameterServerConfig::default());

        assert_eq!(server.server_rank, 0);
        assert_eq!(server.get_version(), 0);
        assert_eq!(server.num_parameters(), 0);
    }

    /// Struct-update syntax overrides only the named fields.
    #[tokio::test]
    async fn test_parameter_server_config() {
        let cfg = ParameterServerConfig {
            learning_rate: 0.001,
            use_momentum: false,
            gradient_clip_value: Some(1.0),
            ..Default::default()
        };

        assert!(!cfg.use_momentum);
        assert_eq!(cfg.learning_rate, 0.001);
        assert_eq!(cfg.gradient_clip_value, Some(1.0));
    }

    /// A new client records its identifiers and starts at local version 0.
    #[tokio::test]
    async fn test_parameter_server_client() {
        let client = ParameterServerClient::new(0, 1);

        assert_eq!(client.server_rank, 0);
        assert_eq!(client.worker_id, 1);
        assert_eq!(client.get_local_version(), 0);
    }

    /// The stats struct round-trips the values it was built with.
    #[tokio::test]
    async fn test_parameter_server_stats() {
        let stats = ParameterServerStats {
            num_parameters: 100,
            total_pushes: 50,
            total_pulls: 30,
            current_version: 10,
            active_workers: 3,
            memory_usage_mb: 128.5,
        };

        assert_eq!(stats.num_parameters, 100);
        assert_eq!(stats.total_pushes, 50);
        assert_eq!(stats.active_workers, 3);
        assert_eq!(stats.memory_usage_mb, 128.5);
    }

    /// Placeholder for a full integration test; needs a live RPC runtime.
    #[tokio::test]
    #[ignore] // Requires RPC initialization
    async fn test_parameter_server_integration() -> TorshResult<()> {
        // A real test would: initialize RPC, start the server, create
        // clients, exercise push/pull round-trips, and verify versions and
        // statistics. That needs multi-process coordination, so it is
        // skipped here.
        let _server = ParameterServer::new(0, ParameterServerConfig::default());

        Ok(())
    }
}