Skip to main content

torsh_distributed/
lib.rs

1//! Distributed training support for ToRSh
2//!
3//! This crate provides distributed training capabilities including:
4//! - Data parallel training (DDP)
5//! - Model parallel training
6//! - Pipeline parallelism
7//! - Collective communication operations
8//! - RPC framework
9
10// Version information
11pub const VERSION: &str = env!("CARGO_PKG_VERSION");
12pub const VERSION_MAJOR: u32 = 0;
13pub const VERSION_MINOR: u32 = 1;
14pub const VERSION_PATCH: u32 = 0;
15
16use thiserror::Error;
17use torsh_core::TorshError;
18
19/// Type alias for Results with TorshDistributedError
20pub type TorshResult<T> = std::result::Result<T, TorshDistributedError>;
21
22/// Distributed training specific errors with detailed context
23#[derive(Error, Debug)]
24pub enum TorshDistributedError {
25    #[error("Backend not initialized. Please call init_process_group() before performing distributed operations")]
26    BackendNotInitialized,
27
28    #[error("Invalid argument '{arg}': {reason}. Expected: {expected}")]
29    InvalidArgument {
30        arg: String,
31        reason: String,
32        expected: String,
33    },
34
35    #[error("Communication error in operation '{operation}': {cause}. This may be due to network issues, process failures, or backend problems")]
36    CommunicationError { operation: String, cause: String },
37
38    #[error("Backend '{backend}' error: {message}. Check backend configuration and availability")]
39    BackendError { backend: String, message: String },
40
41    #[error("Rank out of bounds: rank {rank} >= world_size {world_size}. Valid ranks are 0 to {}", .world_size - 1)]
42    RankOutOfBounds { rank: u32, world_size: u32 },
43
44    #[error("Feature '{feature}' not available in this build. Enable feature flags: {required_features}")]
45    FeatureNotAvailable {
46        feature: String,
47        required_features: String,
48    },
49
50    #[error(
51        "Process group not found with id '{group_id}'. Available groups: {available_groups:?}"
52    )]
53    ProcessGroupNotFound {
54        group_id: String,
55        available_groups: Vec<String>,
56    },
57
58    #[error("Tensor shape mismatch: expected {expected:?}, got {actual:?}. All tensors in collective operations must have the same shape")]
59    TensorShapeMismatch {
60        expected: Vec<usize>,
61        actual: Vec<usize>,
62    },
63
64    #[error("Timeout after {timeout_secs}s waiting for operation '{operation}'. This may indicate network issues or process failures")]
65    OperationTimeout {
66        operation: String,
67        timeout_secs: u64,
68    },
69
70    #[error("Process {rank} failed during operation '{operation}': {cause}. Consider using fault tolerance features")]
71    ProcessFailure {
72        rank: u32,
73        operation: String,
74        cause: String,
75    },
76
77    #[error("Memory allocation failed: requested {requested_bytes} bytes for '{context}'. Available memory may be insufficient")]
78    MemoryAllocationFailed {
79        requested_bytes: usize,
80        context: String,
81    },
82
83    #[error("Serialization error: {0}")]
84    SerializationError(String),
85
86    #[error("I/O error: {0}")]
87    IoError(String),
88
89    #[error("Internal error: {0}")]
90    InternalError(String),
91
92    #[error("Configuration error: {message}. Check your distributed training configuration")]
93    ConfigurationError { message: String },
94
95    #[error("Checkpoint error: {operation} failed - {cause}. Check filesystem permissions and disk space")]
96    CheckpointError { operation: String, cause: String },
97}
98
99impl TorshDistributedError {
100    /// Create an invalid argument error with context
101    pub fn invalid_argument(
102        arg: impl Into<String>,
103        reason: impl Into<String>,
104        expected: impl Into<String>,
105    ) -> Self {
106        Self::InvalidArgument {
107            arg: arg.into(),
108            reason: reason.into(),
109            expected: expected.into(),
110        }
111    }
112
113    /// Create a communication error with operation context
114    pub fn communication_error(operation: impl Into<String>, cause: impl Into<String>) -> Self {
115        Self::CommunicationError {
116            operation: operation.into(),
117            cause: cause.into(),
118        }
119    }
120
121    /// Create a backend error with backend type
122    pub fn backend_error(backend: impl Into<String>, message: impl Into<String>) -> Self {
123        Self::BackendError {
124            backend: backend.into(),
125            message: message.into(),
126        }
127    }
128
129    /// Create a feature not available error with required features
130    pub fn feature_not_available(
131        feature: impl Into<String>,
132        required_features: impl Into<String>,
133    ) -> Self {
134        Self::FeatureNotAvailable {
135            feature: feature.into(),
136            required_features: required_features.into(),
137        }
138    }
139
140    /// Create a process group not found error
141    pub fn process_group_not_found(
142        group_id: impl Into<String>,
143        available_groups: Vec<String>,
144    ) -> Self {
145        Self::ProcessGroupNotFound {
146            group_id: group_id.into(),
147            available_groups,
148        }
149    }
150
151    /// Create a tensor shape mismatch error
152    pub fn tensor_shape_mismatch(expected: Vec<usize>, actual: Vec<usize>) -> Self {
153        Self::TensorShapeMismatch { expected, actual }
154    }
155
156    /// Create an operation timeout error
157    pub fn operation_timeout(operation: impl Into<String>, timeout_secs: u64) -> Self {
158        Self::OperationTimeout {
159            operation: operation.into(),
160            timeout_secs,
161        }
162    }
163
164    /// Create a process failure error
165    pub fn process_failure(
166        rank: u32,
167        operation: impl Into<String>,
168        cause: impl Into<String>,
169    ) -> Self {
170        Self::ProcessFailure {
171            rank,
172            operation: operation.into(),
173            cause: cause.into(),
174        }
175    }
176
177    /// Create a memory allocation failure error
178    pub fn memory_allocation_failed(requested_bytes: usize, context: impl Into<String>) -> Self {
179        Self::MemoryAllocationFailed {
180            requested_bytes,
181            context: context.into(),
182        }
183    }
184
185    /// Create a serialization error
186    pub fn serialization_error(message: impl Into<String>) -> Self {
187        Self::SerializationError(message.into())
188    }
189
190    /// Create an I/O error
191    pub fn io_error(message: impl Into<String>) -> Self {
192        Self::IoError(message.into())
193    }
194
195    /// Create an internal error
196    pub fn internal_error(message: impl Into<String>) -> Self {
197        Self::InternalError(message.into())
198    }
199
200    /// Create a configuration error
201    pub fn configuration_error(message: impl Into<String>) -> Self {
202        Self::ConfigurationError {
203            message: message.into(),
204        }
205    }
206
207    /// Create a checkpoint error
208    pub fn checkpoint_error(operation: impl Into<String>, cause: impl Into<String>) -> Self {
209        Self::CheckpointError {
210            operation: operation.into(),
211            cause: cause.into(),
212        }
213    }
214
215    /// Create a not implemented error
216    pub fn not_implemented(feature: impl Into<String>) -> Self {
217        Self::FeatureNotAvailable {
218            feature: feature.into(),
219            required_features: "Not yet implemented".to_string(),
220        }
221    }
222
223    /// Check if this error is retryable
224    pub fn is_retryable(&self) -> bool {
225        match self {
226            Self::CommunicationError { .. } => true,
227            Self::OperationTimeout { .. } => true,
228            Self::ProcessFailure { .. } => true,
229            Self::MemoryAllocationFailed { .. } => false,
230            Self::BackendNotInitialized => false,
231            Self::InvalidArgument { .. } => false,
232            Self::TensorShapeMismatch { .. } => false,
233            Self::FeatureNotAvailable { .. } => false,
234            Self::ProcessGroupNotFound { .. } => false,
235            Self::SerializationError(_) => false,
236            Self::IoError(_) => true,
237            Self::InternalError(_) => false,
238            Self::ConfigurationError { .. } => false,
239            Self::CheckpointError { .. } => true,
240            Self::BackendError { .. } => true,
241            Self::RankOutOfBounds { .. } => false,
242        }
243    }
244
245    /// Get suggested recovery actions for this error
246    pub fn recovery_suggestions(&self) -> Vec<&'static str> {
247        match self {
248            Self::BackendNotInitialized => vec![
249                "Call init_process_group() before performing distributed operations",
250                "Ensure all processes initialize the backend with the same configuration",
251            ],
252            Self::CommunicationError { .. } => vec![
253                "Check network connectivity between processes",
254                "Verify all processes are running and responsive",
255                "Consider using retry mechanisms",
256                "Check firewall and port configurations",
257            ],
258            Self::OperationTimeout { .. } => vec![
259                "Increase timeout duration",
260                "Check for process failures or network issues",
261                "Verify all processes are participating in the operation",
262                "Consider using asynchronous operations",
263            ],
264            Self::ProcessFailure { .. } => vec![
265                "Enable fault tolerance features",
266                "Check process health and system resources",
267                "Consider using elastic training",
268                "Implement checkpoint/restart mechanisms",
269            ],
270            Self::MemoryAllocationFailed { .. } => vec![
271                "Reduce batch size or model size",
272                "Enable CPU offloading for gradients/parameters",
273                "Use gradient compression",
274                "Check available system memory",
275            ],
276            Self::TensorShapeMismatch { .. } => vec![
277                "Ensure all processes use tensors with identical shapes",
278                "Check data preprocessing and model definitions",
279                "Verify tensor creation is consistent across processes",
280            ],
281            Self::FeatureNotAvailable { .. } => vec![
282                "Rebuild with required feature flags enabled",
283                "Install necessary system dependencies",
284                "Use alternative backends or operations",
285            ],
286            _ => vec![
287                "Check configuration and documentation",
288                "Enable debug logging for more details",
289                "Consider using fallback options",
290            ],
291        }
292    }
293}
294
295impl From<TorshDistributedError> for TorshError {
296    fn from(err: TorshDistributedError) -> Self {
297        TorshError::Other(err.to_string())
298    }
299}
300
301impl From<TorshError> for TorshDistributedError {
302    fn from(err: TorshError) -> Self {
303        TorshDistributedError::InternalError(err.to_string())
304    }
305}
306
307pub mod advanced_monitoring;
308pub mod alerting;
309pub mod backend;
310pub mod bottleneck_detection;
311pub mod collectives;
312pub mod communication;
313pub mod communication_scheduler;
314pub mod dask_integration;
315pub mod ddp;
316pub mod debugging;
317pub mod deepspeed_integration;
318pub mod distributed_memory_optimization;
319pub mod distributed_monitoring;
320pub mod edge_computing;
321pub mod enhanced_benchmarks;
322pub mod enhanced_fault_tolerance;
323pub mod error_recovery;
324pub mod expert_parallelism;
325pub mod fairscale_integration;
326pub mod fault_tolerance;
327pub mod fsdp;
328pub mod gradient_compression;
329pub mod gradient_compression_enhanced;
330pub mod green_computing;
331pub mod horovod_integration;
332pub mod metrics;
333pub mod network_aware_compression;
334pub mod parameter_server;
335pub mod pipeline;
336pub mod process_group;
337pub mod profiling;
338pub mod prometheus_exporter;
339pub mod ray_integration;
340pub mod rdma_support;
341pub mod rpc;
342pub mod store;
343pub mod tensor_parallel;
344pub mod three_d_parallelism;
345pub mod training_analytics_dashboard;
346pub mod visualization;
347pub mod zero_3_cpu_offload;
348
349#[cfg(feature = "nccl")]
350pub mod nccl_ops;
351
352#[cfg(feature = "nccl")]
353pub mod nccl_optimization;
354
355// Re-exports
356pub use backend::{Backend, BackendType, ReduceOp};
357pub use bottleneck_detection::{
358    init_global_bottleneck_detector, run_global_bottleneck_detection,
359    with_global_bottleneck_detector, Bottleneck, BottleneckDetectionConfig, BottleneckDetector,
360    BottleneckSeverity, BottleneckThresholds, BottleneckType,
361};
362pub use collectives::{
363    all_gather,
364    // Group-aware operations
365    all_gather_group,
366    all_reduce,
367    all_reduce_group,
368    all_to_all,
369    barrier,
370    barrier_group,
371    broadcast,
372    broadcast_group,
373    bucket_all_reduce,
374    hierarchical_all_reduce,
375    irecv,
376    isend,
377    recv,
378    recv_group,
379    reduce,
380    reduce_group,
381    // Custom collective operations
382    reduce_scatter,
383    ring_all_reduce,
384    scatter,
385    send,
386    send_group,
387    // Communication group management
388    CommunicationGroup,
389    GroupManager,
390};
391pub use communication::{
392    deserialize_message, deserialize_tensor, retry_with_backoff, serialize_message,
393    serialize_tensor, validate_backend_initialized, validate_rank, with_backend_read,
394    with_backend_write, wrap_communication_error, CommunicationStats, StatsCollector,
395};
396pub use communication_scheduler::{
397    CommunicationOp, CommunicationScheduler, CommunicationTask, Priority, SchedulerConfig,
398    SchedulerStats, SchedulingStrategy,
399};
400pub use dask_integration::{
401    DaskArrayConfig, DaskBagConfig, DaskClusterConfig, DaskClusterType, DaskConfig,
402    DaskDataFrameConfig, DaskDistributedConfig, DaskIntegration, DaskMLConfig, DaskMLSearchMethod,
403    DaskScalingConfig, DaskSchedulerConfig, DaskSecurityConfig, DaskShuffleMethod, DaskStats,
404    DaskWorkerConfig,
405};
406pub use ddp::{
407    BucketConfig, BucketInfo, DistributedDataParallel, GradientSyncStats, OverlapConfig,
408};
409pub use debugging::{
410    get_global_debugger, init_global_debugger, ActiveOperation, CommunicationState, DebugConfig,
411    DebugEvent, DiagnosticResult, DistributedDebugger, LogLevel, ProcessGroupState, ResourceState,
412    SystemStateSnapshot,
413};
414pub use deepspeed_integration::{
415    ActivationCheckpointingConfig, DeepSpeedConfig, DeepSpeedIntegration, DeepSpeedStats,
416    FP16Config, OffloadOptimizerConfig, OffloadParamConfig, ZeroOptimizationConfig, ZeroStage,
417};
418pub use edge_computing::{
419    AdaptiveCommunicationParams, AggregationSchedule, AggregationStrategy,
420    BandwidthAdaptationConfig, BandwidthMonitor, ClientSelectionStrategy, CommunicationManager,
421    ComputeCapability, ConnectionType, DataInfo, DataLimits, DeviceDiscoveryConfig, DeviceLocation,
422    DeviceResources, DeviceStatus, DeviceType, DiscoveryProtocol, EdgeComputingConfig,
423    EdgeComputingManager, EdgeDevice, EdgeOptimizationConfig, FederatedAlgorithm,
424    FederatedLearningConfig, FederatedLearningCoordinator, HierarchicalTrainingConfig,
425    HierarchicalTrainingCoordinator, NetworkInfo, PrivacyConfig, PrivacyLevel, PrivacyManager,
426    PrivacyMechanism, ThermalState, TrainingTier,
427};
428pub use error_recovery::{
429    CircuitBreaker, CircuitBreakerConfig, CircuitBreakerState, FailureDetector, HealthChecker,
430    HealthStatus, RetryConfig, RetryExecutor, RetryStats,
431};
432pub use expert_parallelism::{
433    DistributedExpertManager, ExpertAssignment, ExpertParallelismConfig, ExpertParameters,
434    ExpertRouter, ExpertShardingStrategy, RoutingDecision, RoutingStats,
435};
436pub use fairscale_integration::{
437    FairScaleActivationCheckpointingConfig, FairScaleAutoWrapPolicy, FairScaleBalanceMode,
438    FairScaleCheckpointingStrategy, FairScaleConfig, FairScaleFsdpConfig,
439    FairScaleGradScalerConfig, FairScaleIntegration, FairScaleMemoryOptimizationConfig,
440    FairScaleOssConfig, FairScalePipelineConfig, FairScalePipelineSchedule, FairScaleStats,
441};
442pub use fault_tolerance::{
443    checkpoint_utils, CheckpointConfig, CheckpointManager, DistributedMetadata, ElasticConfig,
444    ElasticTrainingManager, ScalingEvent, ScalingState, TrainingCheckpoint,
445};
446pub use fsdp::{
447    auto_wrap_modules, fsdp_wrap, AutoWrapPolicy, BackwardPrefetch, FsdpConfig,
448    FullyShardedDataParallel, MemoryConfig, MemoryStats, MixedPrecisionConfig,
449    ShardInfo as FsdpShardInfo, ShardingStrategy,
450};
451pub use gradient_compression::{
452    CompressedData, CompressedGradient, CompressionConfig, CompressionMetadata, CompressionMethod,
453    CompressionStats, GradientCompressor,
454};
455pub use green_computing::{
456    CarbonFootprintData, DeviceEnergyData, GreenComputingConfig, GreenComputingManager,
457    GreenOptimizationStrategy, GreenTrainingScheduler, PowerManagementStrategy,
458    RenewableEnergyData, ScheduleAction, SustainabilityMetrics, SustainabilityReport,
459    SustainabilityReportingConfig, TrainingScheduleRecommendation, TrainingWindow,
460};
461pub use horovod_integration::{
462    HorovodCompressionConfig, HorovodCompressionType, HorovodConfig, HorovodElasticConfig,
463    HorovodIntegration, HorovodOptimizerFusionConfig, HorovodStats, HorovodTimelineConfig,
464};
465pub use metrics::{
466    get_global_metrics_collector, init_global_metrics_collector, start_global_metrics_collection,
467    stop_global_metrics_collection, CommunicationMetrics, MetricsCollector, MetricsConfig,
468    PerformanceMetrics, SystemMetrics, TimeSeries, TimeSeriesPoint, TrainingMetrics,
469};
470pub use parameter_server::{
471    ParameterServer, ParameterServerClient, ParameterServerConfig, ParameterServerMessage,
472    ParameterServerResponse, ParameterServerStats,
473};
474pub use pipeline::{
475    create_pipeline_stages, PipelineConfig, PipelineParallel, PipelineStage, PipelineStats,
476    ScheduleType,
477};
478pub use process_group::{ProcessGroup, Rank, WorldSize};
479pub use profiling::{
480    get_global_profiler, init_global_profiler, CommunicationEvent, CommunicationOpType,
481    CommunicationProfiler, OperationStats, ProfilingConfig, ProfilingTimer,
482};
483pub use ray_integration::{
484    RayCheckpointConfig, RayClusterConfig, RayConfig, RayDataConfig, RayDataFormat,
485    RayFailureConfig, RayFaultToleranceConfig, RayIntegration, RayPlacementGroupStrategy,
486    RayResourceConfig, RayRunConfig, RayScalingConfig, RayScheduler, RaySearchAlgorithm,
487    RayServeConfig, RayStats, RayTrainBackend, RayTrainConfig, RayTuneConfig,
488};
489pub use rdma_support::{
490    CompletionStatus, MemoryAccess, MemoryRegion, MemoryRegistration, RdmaConfig,
491    RdmaConnectionManager, RdmaEndpoint, RdmaError, RdmaMemoryPool, RdmaMemoryPoolConfig,
492    RdmaOperation, RdmaProtocol, RdmaQoS, RdmaResult, RdmaStatistics, RdmaTensorScheduler,
493    WorkCompletion, WorkRequest,
494};
495pub use rpc::{
496    delete_rref, get_worker_rank, get_world_size, init_rpc, is_initialized, register_function,
497    remote, rpc_async, shutdown, RRef, RpcBackendOptions, RpcMessage,
498};
499pub use store::{
500    create_store, FileStore, MemoryStore, Store, StoreBackend, StoreConfig, StoreValue,
501};
502pub use tensor_parallel::{
503    ShardInfo as TpShardInfo, TensorParallel, TensorParallelConfig, TensorParallelLayer,
504    TensorParallelStats, TensorParallelStrategy,
505};
506pub use three_d_parallelism::{
507    CommunicationStrategy, LayerShard, LayerType, Memory3DStats, MemoryOptimizationStrategy,
508    ModelShards, Performance3DStats, RankMapping, ThreeDParallelismConfig,
509    ThreeDParallelismCoordinator,
510};
511pub use training_analytics_dashboard::{
512    CommunicationAnalytics, CommunicationHotspot, CommunicationPatterns, ConvergenceAnalytics,
513    DashboardConfig, DashboardExport, EfficiencyAnalytics, OptimizationRecommendation,
514    RecommendationCategory, ResourceBottleneck, ResourceUtilizationAnalytics,
515    SystemHealthAnalytics, TrainingAnalytics, TrainingAnalyticsDashboard,
516    TrainingPerformanceAnalytics, TrainingSummaryReport,
517};
518pub use visualization::{
519    generate_communication_network_html, generate_monitoring_dashboard, Chart, ChartSeries,
520    ChartType, ColorScheme, Dashboard, DashboardLayout, DataPoint, VisualizationConfig,
521    VisualizationGenerator,
522};
523pub use zero_3_cpu_offload::{
524    AutoMemoryStrategy, ConfigModelParameters as ModelParameters, CpuCompressionMethod,
525    Zero3CpuOffloadConfig, Zero3CpuOffloadManager, Zero3MemoryStats, Zero3PerformanceStats,
526};
527
528#[cfg(feature = "nccl")]
529pub use nccl_ops::{
530    nccl_all_gather, nccl_all_reduce, nccl_broadcast, nccl_reduce_scatter, NcclBatch,
531};
532
533#[cfg(feature = "nccl")]
534pub use nccl_optimization::{
535    CudaEvent, CudaStream, FusedNcclOp, FusionStats, GpuMemoryPool, MemoryPoolStats,
536    NcclPerformanceStats, NcclScheduler, OperationStats as NcclOperationStats,
537};
538
539/// Initialize the distributed process group
540pub async fn init_process_group(
541    backend: BackendType,
542    rank: Rank,
543    world_size: WorldSize,
544    master_addr: &str,
545    master_port: u16,
546) -> TorshResult<ProcessGroup> {
547    ProcessGroup::new(backend, rank, world_size, master_addr, master_port).await
548}
549
550/// Check if distributed training is available
551#[allow(unexpected_cfgs)]
552pub fn is_available() -> bool {
553    // Always return true since we have MockBackend available
554    true
555}
556
557/// Check if NCCL backend is available
558#[allow(unexpected_cfgs)]
559pub fn is_nccl_available() -> bool {
560    cfg!(feature = "nccl") && cfg!(feature = "cuda")
561}
562
563/// Check if MPI backend is available  
564pub fn is_mpi_available() -> bool {
565    cfg!(feature = "mpi")
566}
567
568/// Check if Gloo backend is available
569#[allow(unexpected_cfgs)]
570pub fn is_gloo_available() -> bool {
571    // Mock backend pretends to be Gloo
572    true
573}
574
575#[cfg(test)]
576mod tests {
577    use super::*;
578    use log::info;
579
580    #[test]
581    fn test_availability() {
582        // At least one backend should be available
583        let available = is_available();
584        info!("Distributed training available: {}", available);
585    }
586}
587
588/// Prelude module for convenient imports
589pub mod prelude {
590    pub use crate::{TorshDistributedError, TorshResult};
591    // Re-export public items from modules when they exist
592}