1pub const VERSION: &str = env!("CARGO_PKG_VERSION");
12pub const VERSION_MAJOR: u32 = 0;
13pub const VERSION_MINOR: u32 = 1;
14pub const VERSION_PATCH: u32 = 0;
15
16use thiserror::Error;
17use torsh_core::TorshError;
18
/// Crate-wide result alias; the error type is always [`TorshDistributedError`].
pub type TorshResult<T> = std::result::Result<T, TorshDistributedError>;
21
22#[derive(Error, Debug)]
24pub enum TorshDistributedError {
25 #[error("Backend not initialized. Please call init_process_group() before performing distributed operations")]
26 BackendNotInitialized,
27
28 #[error("Invalid argument '{arg}': {reason}. Expected: {expected}")]
29 InvalidArgument {
30 arg: String,
31 reason: String,
32 expected: String,
33 },
34
35 #[error("Communication error in operation '{operation}': {cause}. This may be due to network issues, process failures, or backend problems")]
36 CommunicationError { operation: String, cause: String },
37
38 #[error("Backend '{backend}' error: {message}. Check backend configuration and availability")]
39 BackendError { backend: String, message: String },
40
41 #[error("Rank out of bounds: rank {rank} >= world_size {world_size}. Valid ranks are 0 to {}", .world_size - 1)]
42 RankOutOfBounds { rank: u32, world_size: u32 },
43
44 #[error("Feature '{feature}' not available in this build. Enable feature flags: {required_features}")]
45 FeatureNotAvailable {
46 feature: String,
47 required_features: String,
48 },
49
50 #[error(
51 "Process group not found with id '{group_id}'. Available groups: {available_groups:?}"
52 )]
53 ProcessGroupNotFound {
54 group_id: String,
55 available_groups: Vec<String>,
56 },
57
58 #[error("Tensor shape mismatch: expected {expected:?}, got {actual:?}. All tensors in collective operations must have the same shape")]
59 TensorShapeMismatch {
60 expected: Vec<usize>,
61 actual: Vec<usize>,
62 },
63
64 #[error("Timeout after {timeout_secs}s waiting for operation '{operation}'. This may indicate network issues or process failures")]
65 OperationTimeout {
66 operation: String,
67 timeout_secs: u64,
68 },
69
70 #[error("Process {rank} failed during operation '{operation}': {cause}. Consider using fault tolerance features")]
71 ProcessFailure {
72 rank: u32,
73 operation: String,
74 cause: String,
75 },
76
77 #[error("Memory allocation failed: requested {requested_bytes} bytes for '{context}'. Available memory may be insufficient")]
78 MemoryAllocationFailed {
79 requested_bytes: usize,
80 context: String,
81 },
82
83 #[error("Serialization error: {0}")]
84 SerializationError(String),
85
86 #[error("I/O error: {0}")]
87 IoError(String),
88
89 #[error("Internal error: {0}")]
90 InternalError(String),
91
92 #[error("Configuration error: {message}. Check your distributed training configuration")]
93 ConfigurationError { message: String },
94
95 #[error("Checkpoint error: {operation} failed - {cause}. Check filesystem permissions and disk space")]
96 CheckpointError { operation: String, cause: String },
97}
98
99impl TorshDistributedError {
100 pub fn invalid_argument(
102 arg: impl Into<String>,
103 reason: impl Into<String>,
104 expected: impl Into<String>,
105 ) -> Self {
106 Self::InvalidArgument {
107 arg: arg.into(),
108 reason: reason.into(),
109 expected: expected.into(),
110 }
111 }
112
113 pub fn communication_error(operation: impl Into<String>, cause: impl Into<String>) -> Self {
115 Self::CommunicationError {
116 operation: operation.into(),
117 cause: cause.into(),
118 }
119 }
120
121 pub fn backend_error(backend: impl Into<String>, message: impl Into<String>) -> Self {
123 Self::BackendError {
124 backend: backend.into(),
125 message: message.into(),
126 }
127 }
128
129 pub fn feature_not_available(
131 feature: impl Into<String>,
132 required_features: impl Into<String>,
133 ) -> Self {
134 Self::FeatureNotAvailable {
135 feature: feature.into(),
136 required_features: required_features.into(),
137 }
138 }
139
140 pub fn process_group_not_found(
142 group_id: impl Into<String>,
143 available_groups: Vec<String>,
144 ) -> Self {
145 Self::ProcessGroupNotFound {
146 group_id: group_id.into(),
147 available_groups,
148 }
149 }
150
151 pub fn tensor_shape_mismatch(expected: Vec<usize>, actual: Vec<usize>) -> Self {
153 Self::TensorShapeMismatch { expected, actual }
154 }
155
156 pub fn operation_timeout(operation: impl Into<String>, timeout_secs: u64) -> Self {
158 Self::OperationTimeout {
159 operation: operation.into(),
160 timeout_secs,
161 }
162 }
163
164 pub fn process_failure(
166 rank: u32,
167 operation: impl Into<String>,
168 cause: impl Into<String>,
169 ) -> Self {
170 Self::ProcessFailure {
171 rank,
172 operation: operation.into(),
173 cause: cause.into(),
174 }
175 }
176
177 pub fn memory_allocation_failed(requested_bytes: usize, context: impl Into<String>) -> Self {
179 Self::MemoryAllocationFailed {
180 requested_bytes,
181 context: context.into(),
182 }
183 }
184
185 pub fn serialization_error(message: impl Into<String>) -> Self {
187 Self::SerializationError(message.into())
188 }
189
190 pub fn io_error(message: impl Into<String>) -> Self {
192 Self::IoError(message.into())
193 }
194
195 pub fn internal_error(message: impl Into<String>) -> Self {
197 Self::InternalError(message.into())
198 }
199
200 pub fn configuration_error(message: impl Into<String>) -> Self {
202 Self::ConfigurationError {
203 message: message.into(),
204 }
205 }
206
207 pub fn checkpoint_error(operation: impl Into<String>, cause: impl Into<String>) -> Self {
209 Self::CheckpointError {
210 operation: operation.into(),
211 cause: cause.into(),
212 }
213 }
214
215 pub fn not_implemented(feature: impl Into<String>) -> Self {
217 Self::FeatureNotAvailable {
218 feature: feature.into(),
219 required_features: "Not yet implemented".to_string(),
220 }
221 }
222
223 pub fn is_retryable(&self) -> bool {
225 match self {
226 Self::CommunicationError { .. } => true,
227 Self::OperationTimeout { .. } => true,
228 Self::ProcessFailure { .. } => true,
229 Self::MemoryAllocationFailed { .. } => false,
230 Self::BackendNotInitialized => false,
231 Self::InvalidArgument { .. } => false,
232 Self::TensorShapeMismatch { .. } => false,
233 Self::FeatureNotAvailable { .. } => false,
234 Self::ProcessGroupNotFound { .. } => false,
235 Self::SerializationError(_) => false,
236 Self::IoError(_) => true,
237 Self::InternalError(_) => false,
238 Self::ConfigurationError { .. } => false,
239 Self::CheckpointError { .. } => true,
240 Self::BackendError { .. } => true,
241 Self::RankOutOfBounds { .. } => false,
242 }
243 }
244
245 pub fn recovery_suggestions(&self) -> Vec<&'static str> {
247 match self {
248 Self::BackendNotInitialized => vec![
249 "Call init_process_group() before performing distributed operations",
250 "Ensure all processes initialize the backend with the same configuration",
251 ],
252 Self::CommunicationError { .. } => vec![
253 "Check network connectivity between processes",
254 "Verify all processes are running and responsive",
255 "Consider using retry mechanisms",
256 "Check firewall and port configurations",
257 ],
258 Self::OperationTimeout { .. } => vec![
259 "Increase timeout duration",
260 "Check for process failures or network issues",
261 "Verify all processes are participating in the operation",
262 "Consider using asynchronous operations",
263 ],
264 Self::ProcessFailure { .. } => vec![
265 "Enable fault tolerance features",
266 "Check process health and system resources",
267 "Consider using elastic training",
268 "Implement checkpoint/restart mechanisms",
269 ],
270 Self::MemoryAllocationFailed { .. } => vec![
271 "Reduce batch size or model size",
272 "Enable CPU offloading for gradients/parameters",
273 "Use gradient compression",
274 "Check available system memory",
275 ],
276 Self::TensorShapeMismatch { .. } => vec![
277 "Ensure all processes use tensors with identical shapes",
278 "Check data preprocessing and model definitions",
279 "Verify tensor creation is consistent across processes",
280 ],
281 Self::FeatureNotAvailable { .. } => vec![
282 "Rebuild with required feature flags enabled",
283 "Install necessary system dependencies",
284 "Use alternative backends or operations",
285 ],
286 _ => vec![
287 "Check configuration and documentation",
288 "Enable debug logging for more details",
289 "Consider using fallback options",
290 ],
291 }
292 }
293}
294
295impl From<TorshDistributedError> for TorshError {
296 fn from(err: TorshDistributedError) -> Self {
297 TorshError::Other(err.to_string())
298 }
299}
300
301impl From<TorshError> for TorshDistributedError {
302 fn from(err: TorshError) -> Self {
303 TorshDistributedError::InternalError(err.to_string())
304 }
305}
306
// Public submodules (kept in alphabetical order).
pub mod advanced_monitoring;
pub mod alerting;
pub mod backend;
pub mod bottleneck_detection;
pub mod collectives;
pub mod communication;
pub mod communication_scheduler;
pub mod dask_integration;
pub mod ddp;
pub mod debugging;
pub mod deepspeed_integration;
pub mod distributed_memory_optimization;
pub mod distributed_monitoring;
pub mod edge_computing;
pub mod enhanced_benchmarks;
pub mod enhanced_fault_tolerance;
pub mod error_recovery;
pub mod expert_parallelism;
pub mod fairscale_integration;
pub mod fault_tolerance;
pub mod fsdp;
pub mod gradient_compression;
pub mod gradient_compression_enhanced;
pub mod green_computing;
pub mod horovod_integration;
pub mod metrics;
pub mod network_aware_compression;
pub mod parameter_server;
pub mod pipeline;
pub mod process_group;
pub mod profiling;
pub mod prometheus_exporter;
pub mod ray_integration;
pub mod rdma_support;
pub mod rpc;
pub mod store;
pub mod tensor_parallel;
pub mod three_d_parallelism;
pub mod training_analytics_dashboard;
pub mod visualization;
pub mod zero_3_cpu_offload;

// NCCL-backed GPU collectives, compiled only with the `nccl` feature.
#[cfg(feature = "nccl")]
pub mod nccl_ops;

#[cfg(feature = "nccl")]
pub mod nccl_optimization;
354
355pub use backend::{Backend, BackendType, ReduceOp};
357pub use bottleneck_detection::{
358 init_global_bottleneck_detector, run_global_bottleneck_detection,
359 with_global_bottleneck_detector, Bottleneck, BottleneckDetectionConfig, BottleneckDetector,
360 BottleneckSeverity, BottleneckThresholds, BottleneckType,
361};
362pub use collectives::{
363 all_gather,
364 all_gather_group,
366 all_reduce,
367 all_reduce_group,
368 all_to_all,
369 barrier,
370 barrier_group,
371 broadcast,
372 broadcast_group,
373 bucket_all_reduce,
374 hierarchical_all_reduce,
375 irecv,
376 isend,
377 recv,
378 recv_group,
379 reduce,
380 reduce_group,
381 reduce_scatter,
383 ring_all_reduce,
384 scatter,
385 send,
386 send_group,
387 CommunicationGroup,
389 GroupManager,
390};
391pub use communication::{
392 deserialize_message, deserialize_tensor, retry_with_backoff, serialize_message,
393 serialize_tensor, validate_backend_initialized, validate_rank, with_backend_read,
394 with_backend_write, wrap_communication_error, CommunicationStats, StatsCollector,
395};
396pub use communication_scheduler::{
397 CommunicationOp, CommunicationScheduler, CommunicationTask, Priority, SchedulerConfig,
398 SchedulerStats, SchedulingStrategy,
399};
400pub use dask_integration::{
401 DaskArrayConfig, DaskBagConfig, DaskClusterConfig, DaskClusterType, DaskConfig,
402 DaskDataFrameConfig, DaskDistributedConfig, DaskIntegration, DaskMLConfig, DaskMLSearchMethod,
403 DaskScalingConfig, DaskSchedulerConfig, DaskSecurityConfig, DaskShuffleMethod, DaskStats,
404 DaskWorkerConfig,
405};
406pub use ddp::{
407 BucketConfig, BucketInfo, DistributedDataParallel, GradientSyncStats, OverlapConfig,
408};
409pub use debugging::{
410 get_global_debugger, init_global_debugger, ActiveOperation, CommunicationState, DebugConfig,
411 DebugEvent, DiagnosticResult, DistributedDebugger, LogLevel, ProcessGroupState, ResourceState,
412 SystemStateSnapshot,
413};
414pub use deepspeed_integration::{
415 ActivationCheckpointingConfig, DeepSpeedConfig, DeepSpeedIntegration, DeepSpeedStats,
416 FP16Config, OffloadOptimizerConfig, OffloadParamConfig, ZeroOptimizationConfig, ZeroStage,
417};
418pub use edge_computing::{
419 AdaptiveCommunicationParams, AggregationSchedule, AggregationStrategy,
420 BandwidthAdaptationConfig, BandwidthMonitor, ClientSelectionStrategy, CommunicationManager,
421 ComputeCapability, ConnectionType, DataInfo, DataLimits, DeviceDiscoveryConfig, DeviceLocation,
422 DeviceResources, DeviceStatus, DeviceType, DiscoveryProtocol, EdgeComputingConfig,
423 EdgeComputingManager, EdgeDevice, EdgeOptimizationConfig, FederatedAlgorithm,
424 FederatedLearningConfig, FederatedLearningCoordinator, HierarchicalTrainingConfig,
425 HierarchicalTrainingCoordinator, NetworkInfo, PrivacyConfig, PrivacyLevel, PrivacyManager,
426 PrivacyMechanism, ThermalState, TrainingTier,
427};
428pub use error_recovery::{
429 CircuitBreaker, CircuitBreakerConfig, CircuitBreakerState, FailureDetector, HealthChecker,
430 HealthStatus, RetryConfig, RetryExecutor, RetryStats,
431};
432pub use expert_parallelism::{
433 DistributedExpertManager, ExpertAssignment, ExpertParallelismConfig, ExpertParameters,
434 ExpertRouter, ExpertShardingStrategy, RoutingDecision, RoutingStats,
435};
436pub use fairscale_integration::{
437 FairScaleActivationCheckpointingConfig, FairScaleAutoWrapPolicy, FairScaleBalanceMode,
438 FairScaleCheckpointingStrategy, FairScaleConfig, FairScaleFsdpConfig,
439 FairScaleGradScalerConfig, FairScaleIntegration, FairScaleMemoryOptimizationConfig,
440 FairScaleOssConfig, FairScalePipelineConfig, FairScalePipelineSchedule, FairScaleStats,
441};
442pub use fault_tolerance::{
443 checkpoint_utils, CheckpointConfig, CheckpointManager, DistributedMetadata, ElasticConfig,
444 ElasticTrainingManager, ScalingEvent, ScalingState, TrainingCheckpoint,
445};
446pub use fsdp::{
447 auto_wrap_modules, fsdp_wrap, AutoWrapPolicy, BackwardPrefetch, FsdpConfig,
448 FullyShardedDataParallel, MemoryConfig, MemoryStats, MixedPrecisionConfig,
449 ShardInfo as FsdpShardInfo, ShardingStrategy,
450};
451pub use gradient_compression::{
452 CompressedData, CompressedGradient, CompressionConfig, CompressionMetadata, CompressionMethod,
453 CompressionStats, GradientCompressor,
454};
455pub use green_computing::{
456 CarbonFootprintData, DeviceEnergyData, GreenComputingConfig, GreenComputingManager,
457 GreenOptimizationStrategy, GreenTrainingScheduler, PowerManagementStrategy,
458 RenewableEnergyData, ScheduleAction, SustainabilityMetrics, SustainabilityReport,
459 SustainabilityReportingConfig, TrainingScheduleRecommendation, TrainingWindow,
460};
461pub use horovod_integration::{
462 HorovodCompressionConfig, HorovodCompressionType, HorovodConfig, HorovodElasticConfig,
463 HorovodIntegration, HorovodOptimizerFusionConfig, HorovodStats, HorovodTimelineConfig,
464};
465pub use metrics::{
466 get_global_metrics_collector, init_global_metrics_collector, start_global_metrics_collection,
467 stop_global_metrics_collection, CommunicationMetrics, MetricsCollector, MetricsConfig,
468 PerformanceMetrics, SystemMetrics, TimeSeries, TimeSeriesPoint, TrainingMetrics,
469};
470pub use parameter_server::{
471 ParameterServer, ParameterServerClient, ParameterServerConfig, ParameterServerMessage,
472 ParameterServerResponse, ParameterServerStats,
473};
474pub use pipeline::{
475 create_pipeline_stages, PipelineConfig, PipelineParallel, PipelineStage, PipelineStats,
476 ScheduleType,
477};
478pub use process_group::{ProcessGroup, Rank, WorldSize};
479pub use profiling::{
480 get_global_profiler, init_global_profiler, CommunicationEvent, CommunicationOpType,
481 CommunicationProfiler, OperationStats, ProfilingConfig, ProfilingTimer,
482};
483pub use ray_integration::{
484 RayCheckpointConfig, RayClusterConfig, RayConfig, RayDataConfig, RayDataFormat,
485 RayFailureConfig, RayFaultToleranceConfig, RayIntegration, RayPlacementGroupStrategy,
486 RayResourceConfig, RayRunConfig, RayScalingConfig, RayScheduler, RaySearchAlgorithm,
487 RayServeConfig, RayStats, RayTrainBackend, RayTrainConfig, RayTuneConfig,
488};
489pub use rdma_support::{
490 CompletionStatus, MemoryAccess, MemoryRegion, MemoryRegistration, RdmaConfig,
491 RdmaConnectionManager, RdmaEndpoint, RdmaError, RdmaMemoryPool, RdmaMemoryPoolConfig,
492 RdmaOperation, RdmaProtocol, RdmaQoS, RdmaResult, RdmaStatistics, RdmaTensorScheduler,
493 WorkCompletion, WorkRequest,
494};
495pub use rpc::{
496 delete_rref, get_worker_rank, get_world_size, init_rpc, is_initialized, register_function,
497 remote, rpc_async, shutdown, RRef, RpcBackendOptions, RpcMessage,
498};
499pub use store::{
500 create_store, FileStore, MemoryStore, Store, StoreBackend, StoreConfig, StoreValue,
501};
502pub use tensor_parallel::{
503 ShardInfo as TpShardInfo, TensorParallel, TensorParallelConfig, TensorParallelLayer,
504 TensorParallelStats, TensorParallelStrategy,
505};
506pub use three_d_parallelism::{
507 CommunicationStrategy, LayerShard, LayerType, Memory3DStats, MemoryOptimizationStrategy,
508 ModelShards, Performance3DStats, RankMapping, ThreeDParallelismConfig,
509 ThreeDParallelismCoordinator,
510};
511pub use training_analytics_dashboard::{
512 CommunicationAnalytics, CommunicationHotspot, CommunicationPatterns, ConvergenceAnalytics,
513 DashboardConfig, DashboardExport, EfficiencyAnalytics, OptimizationRecommendation,
514 RecommendationCategory, ResourceBottleneck, ResourceUtilizationAnalytics,
515 SystemHealthAnalytics, TrainingAnalytics, TrainingAnalyticsDashboard,
516 TrainingPerformanceAnalytics, TrainingSummaryReport,
517};
518pub use visualization::{
519 generate_communication_network_html, generate_monitoring_dashboard, Chart, ChartSeries,
520 ChartType, ColorScheme, Dashboard, DashboardLayout, DataPoint, VisualizationConfig,
521 VisualizationGenerator,
522};
523pub use zero_3_cpu_offload::{
524 AutoMemoryStrategy, ConfigModelParameters as ModelParameters, CpuCompressionMethod,
525 Zero3CpuOffloadConfig, Zero3CpuOffloadManager, Zero3MemoryStats, Zero3PerformanceStats,
526};
527
528#[cfg(feature = "nccl")]
529pub use nccl_ops::{
530 nccl_all_gather, nccl_all_reduce, nccl_broadcast, nccl_reduce_scatter, NcclBatch,
531};
532
533#[cfg(feature = "nccl")]
534pub use nccl_optimization::{
535 CudaEvent, CudaStream, FusedNcclOp, FusionStats, GpuMemoryPool, MemoryPoolStats,
536 NcclPerformanceStats, NcclScheduler, OperationStats as NcclOperationStats,
537};
538
/// Initializes a process group for distributed training and returns it.
///
/// Thin wrapper around [`ProcessGroup::new`]. The error message on
/// [`TorshDistributedError::BackendNotInitialized`] indicates this must be
/// called before any collective operation.
///
/// # Arguments
/// * `backend` - communication backend to use (e.g. NCCL, MPI, CPU).
/// * `rank` - this process's rank; presumably unique in `0..world_size` — confirm with `ProcessGroup::new`.
/// * `world_size` - total number of participating processes.
/// * `master_addr` / `master_port` - rendezvous address of the coordinating process.
///
/// # Errors
/// Propagates whatever error `ProcessGroup::new` reports (backend or
/// communication failures).
pub async fn init_process_group(
    backend: BackendType,
    rank: Rank,
    world_size: WorldSize,
    master_addr: &str,
    master_port: u16,
) -> TorshResult<ProcessGroup> {
    ProcessGroup::new(backend, rank, world_size, master_addr, master_port).await
}
549
/// Returns `true` if distributed training support is compiled into this build.
///
/// The CPU-based backend is always built, so this is unconditionally `true`.
// The previous `#[allow(unexpected_cfgs)]` was dead: the body contains no
// `cfg!` expressions for the lint to fire on.
pub fn is_available() -> bool {
    true
}
556
/// Returns `true` if NCCL-backed GPU collectives are compiled into this build.
///
/// NCCL support requires both the `nccl` and `cuda` features.
#[allow(unexpected_cfgs)]
pub fn is_nccl_available() -> bool {
    cfg!(all(feature = "nccl", feature = "cuda"))
}
562
/// Returns `true` if the MPI backend is compiled into this build (`mpi` feature).
// `#[allow(unexpected_cfgs)]` added for consistency with the sibling
// availability probes, which carry it for the same `cfg!(feature = ...)` check.
#[allow(unexpected_cfgs)]
pub fn is_mpi_available() -> bool {
    cfg!(feature = "mpi")
}
567
/// Returns `true` if the Gloo-style CPU backend is available.
///
/// The CPU backend has no optional dependencies, so it is always available.
// The previous `#[allow(unexpected_cfgs)]` was dead: the body contains no
// `cfg!` expressions for the lint to fire on.
pub fn is_gloo_available() -> bool {
    true
}
574
#[cfg(test)]
mod tests {
    use super::*;
    use log::info;

    /// `is_available()` must report support; previously this test only logged
    /// the value and could never fail.
    #[test]
    fn test_availability() {
        let available = is_available();
        info!("Distributed training available: {}", available);
        // The CPU backend is always compiled in, so availability must hold.
        assert!(available);
    }
}
587
588pub mod prelude {
590 pub use crate::{TorshDistributedError, TorshResult};
591 }