1use crate::{TorshDistributedError, TorshResult};
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::path::Path;
18
/// Top-level configuration for the Ray integration.
///
/// Every subsystem section is optional; sections that are `None` are simply
/// skipped by the corresponding `setup_*` step during initialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayConfig {
    /// Ray cluster connection and sizing options.
    pub cluster: Option<RayClusterConfig>,
    /// Ray Train (distributed training) options.
    pub train: Option<RayTrainConfig>,
    /// Ray Tune (hyperparameter search) options.
    pub tune: Option<RayTuneConfig>,
    /// Ray Serve (model serving) options.
    pub serve: Option<RayServeConfig>,
    /// Ray Data (data loading/processing) options.
    pub data: Option<RayDataConfig>,
    /// Process-level resource requests.
    pub resources: Option<RayResourceConfig>,
    /// Worker fault-tolerance / restart policy.
    pub fault_tolerance: Option<RayFaultToleranceConfig>,
}
37
/// Connection and sizing options for the Ray cluster.
///
/// All fields are optional; `setup_ray_cluster` substitutes defaults when a
/// value is absent (1 CPU, 0 GPUs, 4 GB memory, 2 GB object store, dashboard
/// on 127.0.0.1:8265).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayClusterConfig {
    /// Address of an existing Ray cluster to connect to.
    pub address: Option<String>,
    /// Redis address for the cluster (not read by the setup logic in this module).
    pub redis_address: Option<String>,
    /// Number of CPUs to request; validated to be > 0 when set.
    pub num_cpus: Option<u32>,
    /// Number of GPUs to request.
    pub num_gpus: Option<u32>,
    /// Cluster memory in gigabytes; validated to be > 0 when set.
    pub memory_gb: Option<f32>,
    /// Object store memory in gigabytes.
    pub object_store_memory_gb: Option<f32>,
    /// Ray namespace to join.
    pub namespace: Option<String>,
    /// Host the dashboard binds to.
    pub dashboard_host: Option<String>,
    /// Port the dashboard binds to.
    pub dashboard_port: Option<u16>,
    /// Whether to start the dashboard (treated as `true` when unset).
    pub include_dashboard: Option<bool>,
}
62
/// Ray Train (distributed training) options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayTrainConfig {
    /// Training framework backend.
    pub backend: RayTrainBackend,
    /// Number of training workers; validated to be > 0.
    pub num_workers: u32,
    /// Whether workers should use GPUs (treated as `false` when unset).
    pub use_gpu: Option<bool>,
    /// Per-worker resource requests, keyed by resource name.
    pub resources_per_worker: Option<HashMap<String, f32>>,
    /// Placement-group strategy (defaults to `Pack` during setup).
    pub placement_group_strategy: Option<RayPlacementGroupStrategy>,
    /// Optional scaling overrides; its `num_workers` must be > 0 when set.
    pub scaling_config: Option<RayScalingConfig>,
    /// Experiment run settings (name, storage path, checkpoint cadence).
    pub run_config: Option<RayRunConfig>,
    /// Checkpoint retention/frequency settings.
    pub checkpoint_config: Option<RayCheckpointConfig>,
    /// Worker failure-handling settings.
    pub failure_config: Option<RayFailureConfig>,
}
85
/// Training framework backend used by Ray Train.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayTrainBackend {
    /// PyTorch backend.
    Torch,
    /// TensorFlow backend.
    TensorFlow,
    /// Horovod backend.
    Horovod,
    /// MPI backend.
    MPI,
    /// User-provided backend.
    Custom,
}
100
/// Placement-group strategy names mirroring Ray's placement-group strategies
/// (pack = co-locate bundles, spread = distribute; `Strict*` variants are the
/// hard-requirement forms).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayPlacementGroupStrategy {
    StrictPack,
    Pack,
    StrictSpread,
    Spread,
}
113
/// Worker-scaling overrides for Ray Train.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayScalingConfig {
    /// Overrides `RayTrainConfig::num_workers` when set (see
    /// `to_elastic_config`); validated to be > 0.
    pub num_workers: Option<u32>,
    /// Whether workers use GPUs.
    pub use_gpu: Option<bool>,
    /// Per-worker resource requests.
    pub resources_per_worker: Option<HashMap<String, f32>>,
    /// Placement-group strategy override.
    pub placement_group_strategy: Option<RayPlacementGroupStrategy>,
    /// Resources reserved for the trainer process itself.
    pub trainer_resources: Option<HashMap<String, f32>>,
}
128
/// Experiment-run settings for Ray Train / Tune.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayRunConfig {
    /// Experiment name.
    pub name: Option<String>,
    /// Where run artifacts/checkpoints are stored.
    pub storage_path: Option<String>,
    /// Stopping criteria, keyed by metric name.
    pub stop: Option<HashMap<String, f32>>,
    /// Checkpoint every N iterations.
    pub checkpoint_freq: Option<u32>,
    /// How many checkpoints to retain.
    pub keep_checkpoints_num: Option<u32>,
    /// Metric used to rank checkpoints.
    pub checkpoint_score_attr: Option<String>,
    /// Whether the score attribute is maximized or minimized.
    pub checkpoint_mode: Option<RayCheckpointMode>,
    /// Verbosity level.
    pub verbose: Option<u32>,
    /// Progress-reporting backend.
    pub progress_reporter: Option<RayProgressReporter>,
}
151
/// Whether the checkpoint score attribute should be maximized or minimized
/// when ranking checkpoints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayCheckpointMode {
    /// Higher score is better.
    Max,
    /// Lower score is better.
    Min,
}
160
/// Progress-reporting sink for training/tuning runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayProgressReporter {
    /// Built-in console reporter.
    Default,
    /// JSON output.
    Json,
    /// TensorBoard logging.
    TensorBoard,
    /// Weights & Biases logging.
    WandB,
    /// MLflow logging.
    MLflow,
}
175
/// Checkpoint retention and frequency settings for Ray Train.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayCheckpointConfig {
    /// Number of checkpoints to keep (setup logs a default of 3 when unset).
    pub num_to_keep: Option<u32>,
    /// Checkpoint every N iterations.
    pub checkpoint_frequency: Option<u32>,
    /// Whether to write a final checkpoint when training ends.
    pub checkpoint_at_end: Option<bool>,
    /// Metric used to rank checkpoints.
    pub checkpoint_score_attribute: Option<String>,
    /// Whether the score attribute is maximized or minimized.
    pub checkpoint_mode: Option<RayCheckpointMode>,
}
190
/// Worker failure policy for Ray Train.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayFailureConfig {
    /// Maximum tolerated failures (setup logs a default of 3 when unset).
    pub max_failures: Option<u32>,
    /// What to do when a worker fails.
    pub failure_handling: Option<RayFailureHandling>,
}
199
/// Action taken when a training worker fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayFailureHandling {
    /// Restart the failed worker.
    Restart,
    /// Ignore the failure and continue.
    Ignore,
    /// Abort the run.
    Fail,
}
210
/// Ray Tune (hyperparameter search) options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayTuneConfig {
    /// Search algorithm for exploring the parameter space.
    pub search_alg: Option<RaySearchAlgorithm>,
    /// Trial scheduler.
    pub scheduler: Option<RayScheduler>,
    /// Number of trials; validated to be > 0 when set (defaults to 10 at run time).
    pub num_samples: Option<u32>,
    /// Maximum trials running at once; validated to be > 0 when set.
    pub max_concurrent_trials: Option<u32>,
    /// Per-trial resource requests, keyed by resource name.
    pub resources_per_trial: Option<HashMap<String, f32>>,
    /// Search space, keyed by parameter name (free-form JSON values).
    pub param_space: Option<HashMap<String, serde_json::Value>>,
    /// Metric to optimize.
    pub metric: Option<String>,
    /// Optimization direction, e.g. "max" or "min".
    pub mode: Option<String>,
    /// Wall-clock budget for the whole search, in seconds.
    pub time_budget_s: Option<f32>,
}
233
/// Search algorithm names mirroring Ray Tune's built-in search algorithms.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RaySearchAlgorithm {
    BasicVariant,
    Random,
    Grid,
    BayesOpt,
    Hyperband,
    BOHB,
    PopulationBasedTraining,
}
252
/// Trial scheduler names mirroring Ray Tune's built-in schedulers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayScheduler {
    FIFO,
    Hyperband,
    ASHA,
    MedianStopping,
    PopulationBasedTraining,
}
267
/// Ray Serve (model serving) options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayServeConfig {
    /// HTTP endpoint settings.
    pub http_options: Option<RayServeHttpOptions>,
    /// gRPC endpoint settings.
    pub grpc_options: Option<RayServeGrpcOptions>,
    /// Model deployments to serve.
    pub deployments: Option<Vec<RayServeDeploymentConfig>>,
}
278
/// HTTP endpoint settings for Ray Serve (setup defaults: 127.0.0.1:8000).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayServeHttpOptions {
    /// Bind host.
    pub host: Option<String>,
    /// Bind port.
    pub port: Option<u16>,
    /// URL prefix for all routes.
    pub root_path: Option<String>,
}
289
/// gRPC endpoint settings for Ray Serve (setup defaults: port 9000).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayServeGrpcOptions {
    /// Bind port.
    pub port: Option<u16>,
    /// Fully-qualified servicer function names to register.
    pub grpc_servicer_functions: Option<Vec<String>>,
}
298
/// A single Ray Serve deployment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayServeDeploymentConfig {
    /// Deployment name (required).
    pub name: String,
    /// Fixed replica count (setup logs a default of 1 when unset).
    pub num_replicas: Option<u32>,
    /// Per-replica actor options (free-form JSON values, e.g. "num_cpus").
    pub ray_actor_options: Option<HashMap<String, serde_json::Value>>,
    /// Arbitrary user configuration passed to the deployment.
    pub user_config: Option<HashMap<String, serde_json::Value>>,
    /// Maximum in-flight queries per replica.
    pub max_concurrent_queries: Option<u32>,
    /// Autoscaling policy; overrides a fixed replica count when present.
    pub autoscaling_config: Option<RayServeAutoscalingConfig>,
}
315
/// Autoscaling policy for a Ray Serve deployment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayServeAutoscalingConfig {
    /// Lower bound on replicas (setup logs a default of 1 when unset).
    pub min_replicas: Option<u32>,
    /// Upper bound on replicas (setup logs a default of 10 when unset).
    pub max_replicas: Option<u32>,
    /// Target concurrent requests per replica that drives scaling.
    pub target_num_ongoing_requests_per_replica: Option<f32>,
    /// How often metrics are sampled, in seconds.
    pub metrics_interval_s: Option<f32>,
    /// Metric window used for scaling decisions, in seconds.
    pub look_back_period_s: Option<f32>,
    /// Damping factor applied to scaling decisions.
    pub smoothing_factor: Option<f32>,
}
332
/// Ray Data (dataset loading) options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayDataConfig {
    /// On-disk data format.
    pub format: Option<RayDataFormat>,
    /// Read parallelism (setup logs a default of 4 when unset).
    pub parallelism: Option<u32>,
    /// Batch size (setup logs a default of 32 when unset).
    pub batch_size: Option<u32>,
    /// Number of batches to prefetch (setup logs a default of 2 when unset).
    pub prefetch: Option<u32>,
    /// Whether to shuffle (treated as `false` when unset).
    pub shuffle: Option<bool>,
    /// Shuffle buffer size; only relevant when `shuffle` is enabled.
    pub shuffle_buffer_size: Option<u32>,
}
349
/// Supported Ray Data input formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RayDataFormat {
    Parquet,
    CSV,
    JSON,
    Image,
    Text,
    Arrow,
}
366
/// Process-level resource requests.
///
/// CPU/GPU counts are fractional (Ray-style fractional resources); memory
/// values are in bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayResourceConfig {
    /// CPUs to request (may be fractional).
    pub num_cpus: Option<f32>,
    /// GPUs to request (may be fractional).
    pub num_gpus: Option<f32>,
    /// Memory in bytes.
    pub memory: Option<u64>,
    /// Object store memory in bytes.
    pub object_store_memory: Option<u64>,
    /// Custom resource requests, keyed by resource name.
    pub custom_resources: Option<HashMap<String, f32>>,
}
381
/// Worker fault-tolerance / restart policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayFaultToleranceConfig {
    /// Maximum restart attempts before a failure is fatal (default 3 at run time).
    pub max_restarts: Option<u32>,
    /// Delay between restarts, in seconds (default 5.0 at run time).
    pub restart_delay_s: Option<f32>,
    /// Health-check interval, in seconds (default 10.0 at run time).
    pub health_check_interval_s: Option<f32>,
    /// Master switch; treated as `true` when unset.
    pub enabled: Option<bool>,
}
394
/// Counters accumulated by the `RayIntegration` run/failure methods.
#[derive(Debug, Clone, Default)]
pub struct RayStats {
    /// Completed `run_training` calls.
    pub training_runs: u64,
    /// Total wall-clock time spent in `run_training`, in seconds.
    pub training_time_sec: f64,
    /// Total trials executed by `run_tuning`.
    pub tuning_trials: u64,
    /// Total wall-clock time spent in `run_tuning`, in seconds.
    pub tuning_time_sec: f64,
    /// Serve request count — not updated anywhere in this module; TODO confirm intended writer.
    pub served_requests: u64,
    /// Data task count — not updated anywhere in this module; TODO confirm intended writer.
    pub data_processing_tasks: u64,
    /// Worker failures observed by `handle_worker_failure`.
    pub worker_failures: u64,
    /// Successful worker restarts performed by `handle_worker_failure`.
    pub restarts: u64,
    /// Utilization gauge — not updated anywhere in this module; TODO confirm intended writer.
    pub resource_utilization: f64,
    /// Checkpoint cadence gauge — not updated anywhere in this module; TODO confirm intended writer.
    pub checkpoint_frequency: f64,
}
419
/// Stateful facade over the Ray subsystems described by a [`RayConfig`].
///
/// Holds the process topology (rank / world size) and accumulated
/// [`RayStats`]. Note that the setup and run methods in this module only
/// validate configuration and log/count — they do not start a real Ray runtime.
pub struct RayIntegration {
    /// Full configuration consulted by validation, setup, and run methods.
    config: RayConfig,
    /// Counters updated by the run/failure-handling methods.
    stats: RayStats,
    /// Set by `initialize`; guards double initialization and the run methods.
    initialized: bool,
    /// Global rank of this process.
    rank: u32,
    /// Total number of processes in the job.
    world_size: u32,
    /// Rank of this process within its node.
    local_rank: u32,
    /// Number of processes on this node.
    local_size: u32,
    /// True between `initialize` and `shutdown`.
    ray_session_active: bool,
}
439
440impl RayIntegration {
441 pub fn new(config: RayConfig) -> Self {
443 Self {
444 config,
445 stats: RayStats::default(),
446 initialized: false,
447 rank: 0,
448 world_size: 1,
449 local_rank: 0,
450 local_size: 1,
451 ray_session_active: false,
452 }
453 }
454
455 pub fn from_file<P: AsRef<Path>>(path: P) -> TorshResult<Self> {
457 let content = std::fs::read_to_string(path).map_err(|e| {
458 TorshDistributedError::configuration_error(format!(
459 "Failed to read Ray config file: {}",
460 e
461 ))
462 })?;
463
464 let config: RayConfig = serde_json::from_str(&content).map_err(|e| {
465 TorshDistributedError::configuration_error(format!("Failed to parse Ray config: {}", e))
466 })?;
467
468 Ok(Self::new(config))
469 }
470
471 pub fn initialize(
473 &mut self,
474 rank: u32,
475 world_size: u32,
476 local_rank: u32,
477 local_size: u32,
478 ) -> TorshResult<()> {
479 if self.initialized {
480 return Err(TorshDistributedError::configuration_error(
481 "Ray integration already initialized",
482 ));
483 }
484
485 self.rank = rank;
486 self.world_size = world_size;
487 self.local_rank = local_rank;
488 self.local_size = local_size;
489
490 self.validate_config()?;
491 self.setup_ray_cluster()?;
492 self.setup_ray_train()?;
493 self.setup_ray_tune()?;
494 self.setup_ray_serve()?;
495 self.setup_ray_data()?;
496 self.setup_fault_tolerance()?;
497
498 self.initialized = true;
499 self.ray_session_active = true;
500
501 tracing::info!(
502 "Ray integration initialized - rank: {}, world_size: {}, local_rank: {}",
503 self.rank,
504 self.world_size,
505 self.local_rank
506 );
507
508 Ok(())
509 }
510
511 fn validate_config(&self) -> TorshResult<()> {
513 if let Some(ref cluster) = self.config.cluster {
515 if let Some(num_cpus) = cluster.num_cpus {
516 if num_cpus == 0 {
517 return Err(TorshDistributedError::configuration_error(
518 "Ray cluster num_cpus must be greater than 0",
519 ));
520 }
521 }
522
523 if let Some(memory_gb) = cluster.memory_gb {
524 if memory_gb <= 0.0 {
525 return Err(TorshDistributedError::configuration_error(
526 "Ray cluster memory_gb must be greater than 0",
527 ));
528 }
529 }
530 }
531
532 if let Some(ref train) = self.config.train {
534 if train.num_workers == 0 {
535 return Err(TorshDistributedError::configuration_error(
536 "Ray Train num_workers must be greater than 0",
537 ));
538 }
539
540 if let Some(ref scaling) = train.scaling_config {
541 if let Some(num_workers) = scaling.num_workers {
542 if num_workers == 0 {
543 return Err(TorshDistributedError::configuration_error(
544 "Ray Train scaling num_workers must be greater than 0",
545 ));
546 }
547 }
548 }
549 }
550
551 if let Some(ref tune) = self.config.tune {
553 if let Some(num_samples) = tune.num_samples {
554 if num_samples == 0 {
555 return Err(TorshDistributedError::configuration_error(
556 "Ray Tune num_samples must be greater than 0",
557 ));
558 }
559 }
560
561 if let Some(max_concurrent) = tune.max_concurrent_trials {
562 if max_concurrent == 0 {
563 return Err(TorshDistributedError::configuration_error(
564 "Ray Tune max_concurrent_trials must be greater than 0",
565 ));
566 }
567 }
568 }
569
570 Ok(())
571 }
572
573 fn setup_ray_cluster(&self) -> TorshResult<()> {
575 if let Some(ref cluster) = self.config.cluster {
576 tracing::info!("Setting up Ray cluster");
577
578 if let Some(ref address) = cluster.address {
579 tracing::debug!("Ray cluster address: {}", address);
580 }
581
582 let num_cpus = cluster.num_cpus.unwrap_or(1);
583 tracing::debug!("Ray cluster CPUs: {}", num_cpus);
584
585 let num_gpus = cluster.num_gpus.unwrap_or(0);
586 tracing::debug!("Ray cluster GPUs: {}", num_gpus);
587
588 let memory_gb = cluster.memory_gb.unwrap_or(4.0);
589 tracing::debug!("Ray cluster memory: {} GB", memory_gb);
590
591 let object_store_memory_gb = cluster.object_store_memory_gb.unwrap_or(2.0);
592 tracing::debug!("Ray object store memory: {} GB", object_store_memory_gb);
593
594 if let Some(ref namespace) = cluster.namespace {
595 tracing::debug!("Ray namespace: {}", namespace);
596 }
597
598 let include_dashboard = cluster.include_dashboard.unwrap_or(true);
599 if include_dashboard {
600 let default_host = "127.0.0.1".to_string();
601 let dashboard_host = cluster.dashboard_host.as_ref().unwrap_or(&default_host);
602 let dashboard_port = cluster.dashboard_port.unwrap_or(8265);
603 tracing::debug!("Ray dashboard: {}:{}", dashboard_host, dashboard_port);
604 }
605 }
606 Ok(())
607 }
608
609 fn setup_ray_train(&self) -> TorshResult<()> {
611 if let Some(ref train) = self.config.train {
612 tracing::info!("Setting up Ray Train");
613
614 tracing::debug!("Ray Train backend: {:?}", train.backend);
615 tracing::debug!("Ray Train workers: {}", train.num_workers);
616
617 let use_gpu = train.use_gpu.unwrap_or(false);
618 tracing::debug!("Ray Train use GPU: {}", use_gpu);
619
620 if let Some(ref resources) = train.resources_per_worker {
621 tracing::debug!("Ray Train resources per worker: {:?}", resources);
622 }
623
624 let placement_strategy = train
625 .placement_group_strategy
626 .unwrap_or(RayPlacementGroupStrategy::Pack);
627 tracing::debug!(
628 "Ray Train placement group strategy: {:?}",
629 placement_strategy
630 );
631
632 if let Some(ref scaling) = train.scaling_config {
633 tracing::debug!("Ray Train scaling configuration: {:?}", scaling);
634 }
635
636 if let Some(ref run_config) = train.run_config {
637 if let Some(ref name) = run_config.name {
638 tracing::debug!("Ray Train experiment name: {}", name);
639 }
640
641 if let Some(ref storage_path) = run_config.storage_path {
642 tracing::debug!("Ray Train storage path: {}", storage_path);
643 }
644 }
645
646 if let Some(ref checkpoint) = train.checkpoint_config {
647 let num_to_keep = checkpoint.num_to_keep.unwrap_or(3);
648 tracing::debug!("Ray Train checkpoints to keep: {}", num_to_keep);
649 }
650
651 if let Some(ref failure) = train.failure_config {
652 let max_failures = failure.max_failures.unwrap_or(3);
653 tracing::debug!("Ray Train max failures: {}", max_failures);
654 }
655 }
656 Ok(())
657 }
658
659 fn setup_ray_tune(&self) -> TorshResult<()> {
661 if let Some(ref tune) = self.config.tune {
662 tracing::info!("Setting up Ray Tune");
663
664 if let Some(search_alg) = tune.search_alg {
665 tracing::debug!("Ray Tune search algorithm: {:?}", search_alg);
666 }
667
668 if let Some(scheduler) = tune.scheduler {
669 tracing::debug!("Ray Tune scheduler: {:?}", scheduler);
670 }
671
672 let num_samples = tune.num_samples.unwrap_or(10);
673 tracing::debug!("Ray Tune samples: {}", num_samples);
674
675 let max_concurrent = tune.max_concurrent_trials.unwrap_or(4);
676 tracing::debug!("Ray Tune max concurrent trials: {}", max_concurrent);
677
678 if let Some(ref resources) = tune.resources_per_trial {
679 tracing::debug!("Ray Tune resources per trial: {:?}", resources);
680 }
681
682 if let Some(ref metric) = tune.metric {
683 tracing::debug!("Ray Tune optimization metric: {}", metric);
684 }
685
686 if let Some(ref mode) = tune.mode {
687 tracing::debug!("Ray Tune optimization mode: {}", mode);
688 }
689
690 if let Some(time_budget) = tune.time_budget_s {
691 tracing::debug!("Ray Tune time budget: {} seconds", time_budget);
692 }
693 }
694 Ok(())
695 }
696
697 fn setup_ray_serve(&self) -> TorshResult<()> {
699 if let Some(ref serve) = self.config.serve {
700 tracing::info!("Setting up Ray Serve");
701
702 if let Some(ref http) = serve.http_options {
703 let default_host = "127.0.0.1".to_string();
704 let host = http.host.as_ref().unwrap_or(&default_host);
705 let port = http.port.unwrap_or(8000);
706 tracing::debug!("Ray Serve HTTP: {}:{}", host, port);
707
708 if let Some(ref root_path) = http.root_path {
709 tracing::debug!("Ray Serve HTTP root path: {}", root_path);
710 }
711 }
712
713 if let Some(ref grpc) = serve.grpc_options {
714 let port = grpc.port.unwrap_or(9000);
715 tracing::debug!("Ray Serve gRPC port: {}", port);
716
717 if let Some(ref functions) = grpc.grpc_servicer_functions {
718 tracing::debug!("Ray Serve gRPC servicer functions: {:?}", functions);
719 }
720 }
721
722 if let Some(ref deployments) = serve.deployments {
723 for deployment in deployments {
724 tracing::debug!("Ray Serve deployment: {}", deployment.name);
725
726 let num_replicas = deployment.num_replicas.unwrap_or(1);
727 tracing::debug!(" Replicas: {}", num_replicas);
728
729 if let Some(ref autoscaling) = deployment.autoscaling_config {
730 let min_replicas = autoscaling.min_replicas.unwrap_or(1);
731 let max_replicas = autoscaling.max_replicas.unwrap_or(10);
732 tracing::debug!(" Autoscaling: {} - {}", min_replicas, max_replicas);
733 }
734 }
735 }
736 }
737 Ok(())
738 }
739
740 fn setup_ray_data(&self) -> TorshResult<()> {
742 if let Some(ref data) = self.config.data {
743 tracing::info!("Setting up Ray Data");
744
745 if let Some(format) = data.format {
746 tracing::debug!("Ray Data format: {:?}", format);
747 }
748
749 let parallelism = data.parallelism.unwrap_or(4);
750 tracing::debug!("Ray Data parallelism: {}", parallelism);
751
752 let batch_size = data.batch_size.unwrap_or(32);
753 tracing::debug!("Ray Data batch size: {}", batch_size);
754
755 let prefetch = data.prefetch.unwrap_or(2);
756 tracing::debug!("Ray Data prefetch: {}", prefetch);
757
758 let shuffle = data.shuffle.unwrap_or(false);
759 tracing::debug!("Ray Data shuffle: {}", shuffle);
760
761 if shuffle {
762 let shuffle_buffer_size = data.shuffle_buffer_size.unwrap_or(1000);
763 tracing::debug!("Ray Data shuffle buffer size: {}", shuffle_buffer_size);
764 }
765 }
766 Ok(())
767 }
768
769 fn setup_fault_tolerance(&self) -> TorshResult<()> {
771 if let Some(ref fault_tolerance) = self.config.fault_tolerance {
772 tracing::info!("Setting up Ray fault tolerance");
773
774 let enabled = fault_tolerance.enabled.unwrap_or(true);
775 tracing::debug!("Ray fault tolerance enabled: {}", enabled);
776
777 if enabled {
778 let max_restarts = fault_tolerance.max_restarts.unwrap_or(3);
779 tracing::debug!("Ray max restarts: {}", max_restarts);
780
781 let restart_delay = fault_tolerance.restart_delay_s.unwrap_or(5.0);
782 tracing::debug!("Ray restart delay: {} seconds", restart_delay);
783
784 let health_check_interval = fault_tolerance.health_check_interval_s.unwrap_or(10.0);
785 tracing::debug!(
786 "Ray health check interval: {} seconds",
787 health_check_interval
788 );
789 }
790 }
791 Ok(())
792 }
793
794 pub fn to_elastic_config(&self) -> TorshResult<Option<crate::fault_tolerance::ElasticConfig>> {
796 if let Some(ref train) = self.config.train {
797 use crate::fault_tolerance::ElasticConfig;
798
799 let min_workers = if let Some(ref scaling) = train.scaling_config {
800 scaling.num_workers.unwrap_or(train.num_workers)
801 } else {
802 train.num_workers
803 };
804
805 let max_workers = min_workers * 2; let config = ElasticConfig {
808 min_workers: min_workers as usize,
809 max_workers: max_workers as usize,
810 scaling_timeout: std::time::Duration::from_secs(300),
811 scaling_check_interval: std::time::Duration::from_secs(30),
812 enable_elastic_scheduling: true,
813 rendezvous_backend: "etcd".to_string(),
814 rendezvous_endpoint: "localhost:2379".to_string(),
815 };
816
817 Ok(Some(config))
818 } else {
819 Ok(None)
820 }
821 }
822
    /// The active Ray configuration.
    pub fn config(&self) -> &RayConfig {
        &self.config
    }

    /// Counters accumulated by the run/failure-handling methods.
    pub fn stats(&self) -> &RayStats {
        &self.stats
    }

    /// True once `initialize` has completed successfully.
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Global rank of this process.
    pub fn rank(&self) -> u32 {
        self.rank
    }

    /// Total number of processes in the job.
    pub fn world_size(&self) -> u32 {
        self.world_size
    }

    /// Rank of this process within its node.
    pub fn local_rank(&self) -> u32 {
        self.local_rank
    }

    /// Number of processes on this node.
    pub fn local_size(&self) -> u32 {
        self.local_size
    }

    /// True between `initialize` and `shutdown`.
    pub fn is_ray_session_active(&self) -> bool {
        self.ray_session_active
    }
862
863 pub fn run_training(&mut self, train_func_name: &str, num_epochs: u32) -> TorshResult<()> {
865 if !self.initialized {
866 return Err(TorshDistributedError::BackendNotInitialized);
867 }
868
869 let start_time = std::time::Instant::now();
870
871 tracing::info!(
872 "Running Ray Train: {} for {} epochs",
873 train_func_name,
874 num_epochs
875 );
876
877 for epoch in 1..=num_epochs {
879 tracing::debug!("Ray Train epoch {}/{}", epoch, num_epochs);
880
881 if epoch % 10 == 0 && self.config.fault_tolerance.is_some() {
883 self.handle_worker_failure()?;
884 }
885 }
886
887 self.stats.training_runs += 1;
889 self.stats.training_time_sec += start_time.elapsed().as_secs_f64();
890
891 tracing::info!("Ray Train completed: {}", train_func_name);
892 Ok(())
893 }
894
895 pub fn run_tuning(&mut self, tune_config_name: &str) -> TorshResult<()> {
897 if !self.initialized {
898 return Err(TorshDistributedError::BackendNotInitialized);
899 }
900
901 let start_time = std::time::Instant::now();
902
903 let num_trials = self
904 .config
905 .tune
906 .as_ref()
907 .and_then(|t| t.num_samples)
908 .unwrap_or(10);
909
910 tracing::info!(
911 "Running Ray Tune: {} with {} trials",
912 tune_config_name,
913 num_trials
914 );
915
916 for trial in 1..=num_trials {
918 tracing::debug!("Ray Tune trial {}/{}", trial, num_trials);
919 self.stats.tuning_trials += 1;
920 }
921
922 self.stats.tuning_time_sec += start_time.elapsed().as_secs_f64();
924
925 tracing::info!("Ray Tune completed: {}", tune_config_name);
926 Ok(())
927 }
928
929 fn handle_worker_failure(&mut self) -> TorshResult<()> {
931 tracing::warn!("Simulating Ray worker failure");
932 self.stats.worker_failures += 1;
933
934 if let Some(ref fault_tolerance) = self.config.fault_tolerance {
935 if fault_tolerance.enabled.unwrap_or(true) {
936 let max_restarts = fault_tolerance.max_restarts.unwrap_or(3);
937
938 if self.stats.restarts < max_restarts as u64 {
939 tracing::info!("Restarting failed Ray worker");
940 self.stats.restarts += 1;
941
942 let restart_delay = fault_tolerance.restart_delay_s.unwrap_or(5.0);
943 tracing::debug!("Ray restart delay: {} seconds", restart_delay);
944 } else {
945 return Err(TorshDistributedError::process_failure(
946 self.rank,
947 "ray_worker",
948 "Maximum restart attempts exceeded",
949 ));
950 }
951 }
952 }
953
954 Ok(())
955 }
956
957 pub fn shutdown(&mut self) -> TorshResult<()> {
959 if self.ray_session_active {
960 tracing::info!("Shutting down Ray integration");
961 self.ray_session_active = false;
962 self.initialized = false;
963 }
964 Ok(())
965 }
966
967 pub fn default_config() -> RayConfig {
969 RayConfig {
970 cluster: Some(RayClusterConfig {
971 address: None,
972 redis_address: None,
973 num_cpus: Some(4),
974 num_gpus: Some(0),
975 memory_gb: Some(8.0),
976 object_store_memory_gb: Some(2.0),
977 namespace: None,
978 dashboard_host: Some("127.0.0.1".to_string()),
979 dashboard_port: Some(8265),
980 include_dashboard: Some(true),
981 }),
982 train: Some(RayTrainConfig {
983 backend: RayTrainBackend::Torch,
984 num_workers: 4,
985 use_gpu: Some(false),
986 resources_per_worker: None,
987 placement_group_strategy: Some(RayPlacementGroupStrategy::Pack),
988 scaling_config: None,
989 run_config: None,
990 checkpoint_config: None,
991 failure_config: Some(RayFailureConfig {
992 max_failures: Some(3),
993 failure_handling: Some(RayFailureHandling::Restart),
994 }),
995 }),
996 tune: None,
997 serve: None,
998 data: Some(RayDataConfig {
999 format: Some(RayDataFormat::Parquet),
1000 parallelism: Some(4),
1001 batch_size: Some(32),
1002 prefetch: Some(2),
1003 shuffle: Some(false),
1004 shuffle_buffer_size: Some(1000),
1005 }),
1006 resources: Some(RayResourceConfig {
1007 num_cpus: Some(4.0),
1008 num_gpus: Some(0.0),
1009 memory: Some(8 * 1024 * 1024 * 1024), object_store_memory: Some(2 * 1024 * 1024 * 1024), custom_resources: None,
1012 }),
1013 fault_tolerance: Some(RayFaultToleranceConfig {
1014 max_restarts: Some(3),
1015 restart_delay_s: Some(5.0),
1016 health_check_interval_s: Some(10.0),
1017 enabled: Some(true),
1018 }),
1019 }
1020 }
1021
1022 pub fn config_with_tune(num_samples: u32, search_alg: RaySearchAlgorithm) -> RayConfig {
1024 let mut config = Self::default_config();
1025
1026 config.tune = Some(RayTuneConfig {
1027 search_alg: Some(search_alg),
1028 scheduler: Some(RayScheduler::ASHA),
1029 num_samples: Some(num_samples),
1030 max_concurrent_trials: Some(4),
1031 resources_per_trial: Some([("cpu".to_string(), 1.0)].into_iter().collect()),
1032 param_space: None,
1033 metric: Some("accuracy".to_string()),
1034 mode: Some("max".to_string()),
1035 time_budget_s: Some(3600.0), });
1037
1038 config
1039 }
1040
1041 pub fn config_with_serve(num_replicas: u32) -> RayConfig {
1043 let mut config = Self::default_config();
1044
1045 config.serve = Some(RayServeConfig {
1046 http_options: Some(RayServeHttpOptions {
1047 host: Some("0.0.0.0".to_string()),
1048 port: Some(8000),
1049 root_path: None,
1050 }),
1051 grpc_options: None,
1052 deployments: Some(vec![RayServeDeploymentConfig {
1053 name: "model".to_string(),
1054 num_replicas: Some(num_replicas),
1055 ray_actor_options: Some(
1056 [(
1057 "num_cpus".to_string(),
1058 serde_json::Value::Number(serde_json::Number::from(1)),
1059 )]
1060 .into_iter()
1061 .collect(),
1062 ),
1063 user_config: None,
1064 max_concurrent_queries: Some(100),
1065 autoscaling_config: Some(RayServeAutoscalingConfig {
1066 min_replicas: Some(1),
1067 max_replicas: Some(num_replicas * 2),
1068 target_num_ongoing_requests_per_replica: Some(10.0),
1069 metrics_interval_s: Some(10.0),
1070 look_back_period_s: Some(30.0),
1071 smoothing_factor: Some(1.0),
1072 }),
1073 }]),
1074 });
1075
1076 config
1077 }
1078}
1079
impl Default for RayConfig {
    /// Delegates to [`RayIntegration::default_config`] so `RayConfig::default()`
    /// and the integration's built-in defaults stay in sync.
    fn default() -> Self {
        RayIntegration::default_config()
    }
}
1085
#[cfg(test)]
mod tests {
    use super::*;

    // Initialization stores the topology and activates the session.
    #[test]
    fn test_ray_config_validation() {
        let config = RayIntegration::default_config();
        let mut integration = RayIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());
        assert!(integration.is_initialized());
        assert!(integration.is_ray_session_active());
        assert_eq!(integration.rank(), 0);
        assert_eq!(integration.world_size(), 4);
        assert_eq!(integration.local_rank(), 0);
    }

    // Each run_training call bumps the run counter and accumulates elapsed time.
    #[test]
    fn test_ray_training_simulation() {
        let config = RayIntegration::default_config();
        let mut integration = RayIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());

        assert!(integration.run_training("my_train_func", 5).is_ok());
        assert!(integration.run_training("another_train_func", 3).is_ok());

        let stats = integration.stats();
        assert_eq!(stats.training_runs, 2);
        assert!(stats.training_time_sec >= 0.0); }

    // run_tuning executes exactly num_samples trials from the tune config.
    #[test]
    fn test_ray_tuning_simulation() {
        let config = RayIntegration::config_with_tune(20, RaySearchAlgorithm::BayesOpt);
        let mut integration = RayIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());

        assert!(integration.run_tuning("hyperparameter_search").is_ok());

        let stats = integration.stats();
        assert_eq!(stats.tuning_trials, 20);
        assert!(stats.tuning_time_sec > 0.0);
    }

    // to_elastic_config maps num_workers -> min and doubles it for max.
    #[test]
    fn test_ray_elastic_config_conversion() {
        let config = RayIntegration::default_config();
        let mut integration = RayIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());

        let elastic_config = integration.to_elastic_config().unwrap();
        assert!(elastic_config.is_some());

        if let Some(config) = elastic_config {
            assert_eq!(config.min_workers, 4);
            assert_eq!(config.max_workers, 8);
            assert!(config.enable_elastic_scheduling);
            assert_eq!(config.rendezvous_backend, "etcd");
        }
    }

    // The default max_restarts of 3 allows three restarts, then fails.
    #[test]
    fn test_ray_worker_failure_handling() {
        let config = RayIntegration::default_config();
        let mut integration = RayIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());

        assert!(integration.handle_worker_failure().is_ok());
        assert!(integration.handle_worker_failure().is_ok());
        assert!(integration.handle_worker_failure().is_ok());

        let stats = integration.stats();
        assert_eq!(stats.worker_failures, 3);
        assert_eq!(stats.restarts, 3);

        assert!(integration.handle_worker_failure().is_err());
    }

    // Shutdown clears both the session-active and initialized flags.
    #[test]
    fn test_ray_shutdown() {
        let config = RayIntegration::default_config();
        let mut integration = RayIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());
        assert!(integration.is_ray_session_active());

        assert!(integration.shutdown().is_ok());
        assert!(!integration.is_ray_session_active());
        assert!(!integration.is_initialized());
    }

    // config_with_serve produces a single "model" deployment with the
    // requested replica count.
    #[test]
    fn test_ray_serve_config() {
        let config = RayIntegration::config_with_serve(4);
        let mut integration = RayIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());

        assert!(integration.config().serve.is_some());

        if let Some(ref serve) = integration.config().serve {
            assert!(serve.http_options.is_some());
            assert!(serve.deployments.is_some());

            if let Some(ref deployments) = serve.deployments {
                assert_eq!(deployments.len(), 1);
                assert_eq!(deployments[0].name, "model");
                assert_eq!(deployments[0].num_replicas, Some(4));
            }
        }
    }

    // RayConfig round-trips through serde_json without losing the tune section.
    #[test]
    fn test_ray_config_serialization() {
        let config = RayIntegration::config_with_tune(10, RaySearchAlgorithm::Random);

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("Random"));
        assert!(json.contains("ASHA"));
        assert!(json.contains("accuracy"));

        let deserialized: RayConfig = serde_json::from_str(&json).unwrap();
        assert!(deserialized.tune.is_some());

        if let Some(tune) = deserialized.tune {
            assert_eq!(tune.search_alg, Some(RaySearchAlgorithm::Random));
            assert_eq!(tune.scheduler, Some(RayScheduler::ASHA));
            assert_eq!(tune.num_samples, Some(10));
        }
    }

    // Zero training workers must be rejected by validate_config.
    #[test]
    fn test_ray_invalid_config() {
        let mut config = RayIntegration::default_config();

        if let Some(ref mut train) = config.train {
            train.num_workers = 0; }

        let mut integration = RayIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_err());
    }
}