1use crate::{TorshDistributedError, TorshResult};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::path::Path;
19
/// Top-level Dask integration configuration.
///
/// Every section is optional; `DaskIntegration` substitutes built-in
/// defaults for anything left unset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskConfig {
    /// Cluster topology, sizing and autoscaling.
    pub cluster: Option<DaskClusterConfig>,
    /// Scheduler endpoint, ports and timeouts.
    pub scheduler: Option<DaskSchedulerConfig>,
    /// Per-worker process/thread/memory settings.
    pub worker: Option<DaskWorkerConfig>,
    /// dask.array chunking options.
    pub array: Option<DaskArrayConfig>,
    /// dask.dataframe partitioning and shuffle options.
    pub dataframe: Option<DaskDataFrameConfig>,
    /// dask.bag partitioning options.
    pub bag: Option<DaskBagConfig>,
    /// Dask-ML settings (model selection, preprocessing, ...).
    pub ml: Option<DaskMLConfig>,
    /// distributed-runtime settings (comm, serialization, client, ...).
    pub distributed: Option<DaskDistributedConfig>,
}
40
/// Cluster topology: backend type, sizing, dashboard and security options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskClusterConfig {
    /// Which cluster backend to launch.
    pub cluster_type: DaskClusterType,
    /// Number of workers to start (validated to be > 0 when set).
    pub n_workers: Option<u32>,
    /// Threads per worker (validated to be > 0 when set).
    pub threads_per_worker: Option<u32>,
    /// Per-worker memory limit as a human-readable string, e.g. "4GB".
    pub memory_limit: Option<String>,
    /// Run workers as separate processes rather than threads.
    pub processes: Option<bool>,
    /// Dashboard bind address, e.g. "127.0.0.1:8787".
    pub dashboard_address: Option<String>,
    /// Suppress cluster log output.
    pub silence_logs: Option<bool>,
    /// Optional TLS/encryption settings.
    pub security: Option<DaskSecurityConfig>,
    /// Optional autoscaling settings.
    pub scaling: Option<DaskScalingConfig>,
}

/// Supported cluster backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DaskClusterType {
    /// In-process local cluster.
    Local,
    /// Local cluster backed by worker processes.
    LocalProcess,
    /// Kubernetes-managed cluster.
    Kubernetes,
    /// SLURM batch scheduler.
    Slurm,
    /// PBS batch scheduler.
    PBS,
    /// Sun Grid Engine batch scheduler.
    SGE,
    /// User-provided cluster implementation.
    Custom,
}

/// TLS settings for cluster communication.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskSecurityConfig {
    /// TLS certificate (presumably a file path — confirm against loader).
    pub tls_cert: Option<String>,
    /// TLS private key (presumably a file path — confirm against loader).
    pub tls_key: Option<String>,
    /// CA bundle file.
    pub tls_ca_file: Option<String>,
    /// Require all connections to be encrypted.
    pub require_encryption: Option<bool>,
}

/// Autoscaling bounds and triggers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskScalingConfig {
    /// Minimum worker count (validated to not exceed `maximum`).
    pub minimum: Option<u32>,
    /// Maximum worker count.
    pub maximum: Option<u32>,
    /// Target CPU utilization (presumably a 0.0-1.0 fraction — confirm).
    pub target_cpu: Option<f32>,
    /// Target memory utilization (presumably a 0.0-1.0 fraction — confirm).
    pub target_memory: Option<f32>,
    /// Utilization above which workers are added.
    pub scale_up_threshold: Option<f32>,
    /// Utilization below which workers are removed.
    pub scale_down_threshold: Option<f32>,
    /// Autoscaler evaluation interval in seconds.
    pub interval: Option<f32>,
}
114
/// Scheduler endpoint, dashboard ports and timeout tuning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskSchedulerConfig {
    /// Address of an existing scheduler to connect to.
    pub address: Option<String>,
    /// Scheduler listen port (default 8786; validated non-zero when set).
    pub port: Option<u16>,
    /// Dashboard port (default 8787).
    pub dashboard_port: Option<u16>,
    /// Legacy Bokeh dashboard port.
    pub bokeh_port: Option<u16>,
    /// Seconds before an unresponsive worker is dropped (validated > 0).
    pub worker_timeout: Option<f32>,
    /// Seconds of inactivity before the scheduler idles out.
    pub idle_timeout: Option<f32>,
    /// Number of task transitions retained in the scheduler log.
    pub transition_log_length: Option<u32>,
    /// Per-task scheduling overhead estimate in seconds.
    pub task_duration_overhead: Option<f32>,
    /// Task failures tolerated before giving up (default 3).
    pub allowed_failures: Option<u32>,
}

/// Worker process sizing, port ranges and environment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskWorkerConfig {
    /// Number of worker processes (validated > 0 when set).
    pub nworkers: Option<u32>,
    /// Threads per worker (validated > 0 when set).
    pub nthreads: Option<u32>,
    /// Per-worker memory limit string, e.g. "4GB".
    pub memory_limit: Option<String>,
    /// Worker port range as "low:high", e.g. "40000:40100".
    pub worker_port: Option<String>,
    /// Nanny port range as "low:high", e.g. "40100:40200".
    pub nanny_port: Option<String>,
    /// Per-worker dashboard port.
    pub dashboard_port: Option<u16>,
    /// Seconds a worker waits for its scheduler before dying.
    pub death_timeout: Option<f32>,
    /// Modules to import on worker startup.
    pub preload: Option<Vec<String>>,
    /// Environment variables injected into workers.
    pub env: Option<HashMap<String, String>>,
    /// Abstract resources advertised by each worker (name -> amount).
    pub resources: Option<HashMap<String, f32>>,
}
162
/// dask.array chunking and backend options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskArrayConfig {
    /// Target chunk size string, e.g. "128MB".
    pub chunk_size: Option<String>,
    /// Array backend name, e.g. "numpy".
    pub backend: Option<String>,
    /// Overlap depth for map_overlap-style operations.
    pub overlap: Option<u32>,
    /// Boundary handling mode, e.g. "reflect".
    pub boundary: Option<String>,
    /// Trim overlap regions after computation.
    pub trim: Option<bool>,
    /// Threshold ratio that triggers rechunking.
    pub rechunk_threshold: Option<f32>,
}

/// dask.dataframe partitioning and shuffle options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskDataFrameConfig {
    /// Target partition size string, e.g. "128MB".
    pub partition_size: Option<String>,
    /// Shuffle strategy used for repartition/merge operations.
    pub shuffle_method: Option<DaskShuffleMethod>,
    /// Enable the query-planning optimizer.
    pub query_planning: Option<bool>,
    /// DataFrame backend name, e.g. "pandas".
    pub backend: Option<String>,
    /// Optimize index usage during operations.
    pub optimize_index: Option<bool>,
}

/// Available shuffle strategies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DaskShuffleMethod {
    /// Spill shuffle data to disk.
    Disk,
    /// Shuffle via the task graph.
    Tasks,
    /// Peer-to-peer shuffle.
    P2P,
}

/// dask.bag partitioning and text-handling options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskBagConfig {
    /// Partition size in bytes.
    pub partition_size: Option<u64>,
    /// Compression codec for bag data, e.g. "gzip".
    pub compression: Option<String>,
    /// Text encoding, e.g. "utf-8".
    pub encoding: Option<String>,
    /// Line delimiter used when splitting text files.
    pub linedelimiter: Option<String>,
}
218
/// Dask-ML settings, grouped by workflow stage. All sections optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLConfig {
    /// Cross-validation / hyperparameter-search settings.
    pub model_selection: Option<DaskMLModelSelectionConfig>,
    /// Feature preprocessing settings.
    pub preprocessing: Option<DaskMLPreprocessingConfig>,
    /// Linear-model solver settings.
    pub linear_model: Option<DaskMLLinearModelConfig>,
    /// Ensemble-model settings.
    pub ensemble: Option<DaskMLEnsembleConfig>,
    /// Clustering settings.
    pub cluster: Option<DaskMLClusterConfig>,
}

/// Cross-validation and hyperparameter-search configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLModelSelectionConfig {
    /// Number of cross-validation folds (default 5).
    pub cv_folds: Option<u32>,
    /// Scoring metric name, e.g. "accuracy".
    pub scoring: Option<String>,
    /// Parallel jobs; -1 conventionally means "all cores" — confirm.
    pub n_jobs: Option<i32>,
    /// Hyperparameter search strategy.
    pub search_method: Option<DaskMLSearchMethod>,
}

/// Hyperparameter search strategies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DaskMLSearchMethod {
    /// Exhaustive grid search.
    GridSearch,
    /// Random sampling of the parameter space.
    RandomSearch,
    /// Successive-halving early stopping.
    SuccessiveHalving,
    /// Hyperband bandit-based search.
    Hyperband,
}

/// Preprocessing component names (scikit-learn-style class names).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLPreprocessingConfig {
    /// Scaler, e.g. "StandardScaler".
    pub standardization: Option<String>,
    /// Encoder, e.g. "OneHotEncoder".
    pub categorical_encoding: Option<String>,
    /// Selector, e.g. "SelectKBest".
    pub feature_selection: Option<String>,
    /// Reducer, e.g. "PCA".
    pub dimensionality_reduction: Option<String>,
}

/// Linear-model solver parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLLinearModelConfig {
    /// Solver name, e.g. "lbfgs".
    pub solver: Option<String>,
    /// Regularization strength.
    pub alpha: Option<f32>,
    /// Maximum solver iterations (default 1000).
    pub max_iter: Option<u32>,
    /// Convergence tolerance.
    pub tol: Option<f32>,
}

/// Ensemble-model parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLEnsembleConfig {
    /// Number of estimators (default 100).
    pub n_estimators: Option<u32>,
    /// Bootstrap samples when building estimators.
    pub bootstrap: Option<bool>,
    /// RNG seed for reproducibility.
    pub random_state: Option<u32>,
    /// Compute out-of-bag score.
    pub oob_score: Option<bool>,
}

/// Clustering (k-means-style) parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLClusterConfig {
    /// Number of clusters (default 8).
    pub n_clusters: Option<u32>,
    /// Initialization scheme, e.g. "k-means++".
    pub init: Option<String>,
    /// Maximum iterations (default 300).
    pub max_iter: Option<u32>,
    /// Convergence tolerance.
    pub tol: Option<f32>,
    /// Number of initializations to try.
    pub n_init: Option<u32>,
}
313
/// Settings for the distributed runtime (communication, serialization,
/// client behavior, scheduling policy, diagnostics).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskDistributedConfig {
    /// Inter-process communication settings.
    pub comm: Option<DaskCommConfig>,
    /// Serialization/compression settings.
    pub serialization: Option<DaskSerializationConfig>,
    /// Client heartbeat and defaults.
    pub client: Option<DaskClientConfig>,
    /// Work-stealing and task-duration heuristics.
    pub scheduling: Option<DaskSchedulingConfig>,
    /// Diagnostic/profiling toggles.
    pub diagnostics: Option<DaskDiagnosticsConfig>,
}

/// Communication-layer settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskCommConfig {
    /// Wire compression codec, e.g. "lz4".
    pub compression: Option<String>,
    /// Ordered list of serializer names, e.g. ["pickle", "msgpack"].
    pub serializers: Option<Vec<String>>,
    /// Connection timeouts.
    pub timeouts: Option<DaskTimeoutsConfig>,
    /// Low-level TCP socket options.
    pub tcp: Option<DaskTcpConfig>,
}

/// Connection timeouts, in seconds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskTimeoutsConfig {
    /// Seconds allowed to establish a connection (default 10).
    pub connect: Option<f32>,
    /// TCP operation timeout in seconds (default 30).
    pub tcp: Option<f32>,
}

/// TCP socket options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskTcpConfig {
    /// Enable SO_REUSEPORT-style port reuse.
    pub reuse_port: Option<bool>,
    /// Disable Nagle's algorithm (TCP_NODELAY).
    pub no_delay: Option<bool>,
    /// Enable TCP keep-alive probes.
    pub keep_alive: Option<bool>,
}

/// Payload serialization settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskSerializationConfig {
    /// Candidate compression codecs, e.g. ["lz4", "zlib"].
    pub compression: Option<Vec<String>>,
    /// Default serializer names.
    pub default_serializers: Option<Vec<String>>,
    /// Python pickle protocol version (default 4).
    pub pickle_protocol: Option<u32>,
}

/// Client-side behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskClientConfig {
    /// Heartbeat interval in seconds (default 5).
    pub heartbeat_interval: Option<f32>,
    /// Scheduler-info polling interval in seconds (default 2).
    pub scheduler_info_interval: Option<f32>,
    /// Metadata fields to attach to tasks.
    pub task_metadata: Option<Vec<String>>,
    /// Register this client as the process-wide default.
    pub set_as_default: Option<bool>,
}

/// Scheduling policy knobs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskSchedulingConfig {
    /// Allow idle workers to steal queued tasks.
    pub work_stealing: Option<bool>,
    /// Work-stealing evaluation interval in seconds.
    pub work_stealing_interval: Option<f32>,
    /// Assumed duration (seconds) for tasks never seen before.
    pub unknown_task_duration: Option<f32>,
    /// Per-task-name default duration estimates in seconds.
    pub default_task_durations: Option<HashMap<String, f32>>,
}

/// Diagnostic and profiling toggles.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskDiagnosticsConfig {
    /// Show a progress bar for computations.
    pub progress_bar: Option<bool>,
    /// Enable task profiling.
    pub profile: Option<bool>,
    /// Enable memory profiling.
    pub memory_profiling: Option<bool>,
    /// Record the task stream.
    pub task_stream: Option<bool>,
    /// Enable the resource monitor.
    pub resource_monitor: Option<bool>,
}
413
/// Runtime counters accumulated by [`DaskIntegration`]. All fields start
/// at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct DaskStats {
    /// Total tasks submitted through `submit_task`.
    pub tasks_executed: u64,
    /// Cumulative task bookkeeping time in seconds.
    pub task_execution_time_sec: f64,
    /// Workers currently considered connected.
    pub workers_connected: u32,
    /// Total bytes recorded as transferred for submitted tasks.
    pub data_transferred_bytes: u64,
    /// Number of task retries observed.
    pub task_retries: u64,
    /// Number of worker failures handled.
    pub worker_failures: u64,
    /// Current memory usage in bytes.
    pub memory_usage_bytes: u64,
    /// CPU utilization (fraction or percent — not set in this file; confirm).
    pub cpu_utilization: f64,
    /// Network bandwidth in bytes per second.
    pub network_bandwidth_bytes_per_sec: f64,
    /// Mean task duration, derived as total time / task count.
    pub average_task_duration_sec: f64,
}
438
/// Facade over a Dask deployment: holds the configuration, tracks runtime
/// statistics, and simulates cluster operations (task submission, scaling,
/// failure handling).
pub struct DaskIntegration {
    // User-supplied configuration (immutable after construction).
    config: DaskConfig,
    // Counters updated by the runtime operations.
    stats: DaskStats,
    // Set by `initialize`; gates all runtime operations.
    initialized: bool,
    // Global rank of this process.
    rank: u32,
    // Total number of processes in the job.
    world_size: u32,
    // Rank of this process on its node.
    local_rank: u32,
    // Number of processes on this node.
    local_size: u32,
    // Whether the (simulated) client connection is live.
    client_active: bool,
}
458
459impl DaskIntegration {
460 pub fn new(config: DaskConfig) -> Self {
462 Self {
463 config,
464 stats: DaskStats::default(),
465 initialized: false,
466 rank: 0,
467 world_size: 1,
468 local_rank: 0,
469 local_size: 1,
470 client_active: false,
471 }
472 }
473
474 pub fn from_file<P: AsRef<Path>>(path: P) -> TorshResult<Self> {
476 let content = std::fs::read_to_string(path).map_err(|e| {
477 TorshDistributedError::configuration_error(format!(
478 "Failed to read Dask config file: {}",
479 e
480 ))
481 })?;
482
483 let config: DaskConfig = serde_json::from_str(&content).map_err(|e| {
484 TorshDistributedError::configuration_error(format!(
485 "Failed to parse Dask config: {}",
486 e
487 ))
488 })?;
489
490 Ok(Self::new(config))
491 }
492
    /// Wire up the integration for a distributed job.
    ///
    /// Records the process topology, validates the configuration, then runs
    /// the setup phases in a fixed order (cluster, scheduler, workers,
    /// client, ML, distributed runtime) before marking the client active.
    ///
    /// # Errors
    /// Fails if already initialized, or if validation or any setup phase
    /// returns an error.
    pub fn initialize(
        &mut self,
        rank: u32,
        world_size: u32,
        local_rank: u32,
        local_size: u32,
    ) -> TorshResult<()> {
        // Re-initialization is a caller bug; refuse rather than re-run setup.
        if self.initialized {
            return Err(TorshDistributedError::configuration_error(
                "Dask integration already initialized",
            ));
        }

        self.rank = rank;
        self.world_size = world_size;
        self.local_rank = local_rank;
        self.local_size = local_size;

        // Validate before any setup so config errors surface first.
        self.validate_config()?;
        self.setup_cluster()?;
        self.setup_scheduler()?;
        self.setup_workers()?;
        self.setup_client()?;
        self.setup_ml()?;
        self.setup_distributed()?;

        self.initialized = true;
        self.client_active = true;

        tracing::info!(
            "Dask integration initialized - rank: {}, world_size: {}, local_rank: {}",
            self.rank,
            self.world_size,
            self.local_rank
        );

        Ok(())
    }
532
533 fn validate_config(&self) -> TorshResult<()> {
535 if let Some(ref cluster) = self.config.cluster {
537 if let Some(n_workers) = cluster.n_workers {
538 if n_workers == 0 {
539 return Err(TorshDistributedError::configuration_error(
540 "Dask cluster n_workers must be greater than 0",
541 ));
542 }
543 }
544
545 if let Some(threads_per_worker) = cluster.threads_per_worker {
546 if threads_per_worker == 0 {
547 return Err(TorshDistributedError::configuration_error(
548 "Dask cluster threads_per_worker must be greater than 0",
549 ));
550 }
551 }
552
553 if let Some(ref scaling) = cluster.scaling {
554 if let Some(minimum) = scaling.minimum {
555 if let Some(maximum) = scaling.maximum {
556 if minimum > maximum {
557 return Err(TorshDistributedError::configuration_error(
558 "Dask scaling minimum workers cannot exceed maximum workers",
559 ));
560 }
561 }
562 }
563 }
564 }
565
566 if let Some(ref scheduler) = self.config.scheduler {
568 if let Some(port) = scheduler.port {
569 if port == 0 {
570 return Err(TorshDistributedError::configuration_error(
571 "Dask scheduler port must be greater than 0",
572 ));
573 }
574 }
575
576 if let Some(worker_timeout) = scheduler.worker_timeout {
577 if worker_timeout <= 0.0 {
578 return Err(TorshDistributedError::configuration_error(
579 "Dask scheduler worker_timeout must be greater than 0",
580 ));
581 }
582 }
583 }
584
585 if let Some(ref worker) = self.config.worker {
587 if let Some(nworkers) = worker.nworkers {
588 if nworkers == 0 {
589 return Err(TorshDistributedError::configuration_error(
590 "Dask worker nworkers must be greater than 0",
591 ));
592 }
593 }
594
595 if let Some(nthreads) = worker.nthreads {
596 if nthreads == 0 {
597 return Err(TorshDistributedError::configuration_error(
598 "Dask worker nthreads must be greater than 0",
599 ));
600 }
601 }
602 }
603
604 Ok(())
605 }
606
    /// Log the effective cluster settings, substituting the built-in
    /// defaults (4 workers, 2 threads each, threads not processes) for any
    /// unset value. No-op when no cluster section is configured.
    fn setup_cluster(&self) -> TorshResult<()> {
        if let Some(ref cluster) = self.config.cluster {
            tracing::info!("Setting up Dask cluster: {:?}", cluster.cluster_type);

            let n_workers = cluster.n_workers.unwrap_or(4);
            tracing::debug!("Dask cluster workers: {}", n_workers);

            let threads_per_worker = cluster.threads_per_worker.unwrap_or(2);
            tracing::debug!("Dask threads per worker: {}", threads_per_worker);

            if let Some(ref memory_limit) = cluster.memory_limit {
                tracing::debug!("Dask memory limit per worker: {}", memory_limit);
            }

            let processes = cluster.processes.unwrap_or(false);
            tracing::debug!("Dask use processes: {}", processes);

            if let Some(ref dashboard_address) = cluster.dashboard_address {
                tracing::debug!("Dask dashboard address: {}", dashboard_address);
            }

            let silence_logs = cluster.silence_logs.unwrap_or(false);
            tracing::debug!("Dask silence logs: {}", silence_logs);

            if let Some(ref security) = cluster.security {
                tracing::debug!("Dask security enabled");
                if let Some(ref tls_cert) = security.tls_cert {
                    tracing::debug!("Dask TLS cert: {}", tls_cert);
                }
                if let Some(ref tls_key) = security.tls_key {
                    tracing::debug!("Dask TLS key: {}", tls_key);
                }
                let require_encryption = security.require_encryption.unwrap_or(false);
                tracing::debug!("Dask require encryption: {}", require_encryption);
            }

            if let Some(ref scaling) = cluster.scaling {
                let minimum = scaling.minimum.unwrap_or(1);
                // Default scaling ceiling is twice the configured worker count.
                let maximum = scaling.maximum.unwrap_or(n_workers * 2);
                tracing::debug!("Dask scaling: {} - {} workers", minimum, maximum);

                let target_cpu = scaling.target_cpu.unwrap_or(0.8);
                tracing::debug!("Dask target CPU utilization: {}", target_cpu);

                let interval = scaling.interval.unwrap_or(30.0);
                tracing::debug!("Dask scaling interval: {} seconds", interval);
            }
        }
        Ok(())
    }
658
659 fn setup_scheduler(&self) -> TorshResult<()> {
661 if let Some(ref scheduler) = self.config.scheduler {
662 tracing::info!("Setting up Dask scheduler");
663
664 if let Some(ref address) = scheduler.address {
665 tracing::debug!("Dask scheduler address: {}", address);
666 }
667
668 let port = scheduler.port.unwrap_or(8786);
669 tracing::debug!("Dask scheduler port: {}", port);
670
671 let dashboard_port = scheduler.dashboard_port.unwrap_or(8787);
672 tracing::debug!("Dask scheduler dashboard port: {}", dashboard_port);
673
674 let worker_timeout = scheduler.worker_timeout.unwrap_or(60.0);
675 tracing::debug!("Dask scheduler worker timeout: {} seconds", worker_timeout);
676
677 let idle_timeout = scheduler.idle_timeout.unwrap_or(1800.0);
678 tracing::debug!("Dask scheduler idle timeout: {} seconds", idle_timeout);
679
680 let allowed_failures = scheduler.allowed_failures.unwrap_or(3);
681 tracing::debug!("Dask scheduler allowed failures: {}", allowed_failures);
682 }
683 Ok(())
684 }
685
    /// Log the effective worker settings and seed `stats.workers_connected`
    /// from the configured worker count (default 4).
    fn setup_workers(&mut self) -> TorshResult<()> {
        if let Some(ref worker) = self.config.worker {
            tracing::info!("Setting up Dask workers");

            let nworkers = worker.nworkers.unwrap_or(4);
            tracing::debug!("Dask number of workers: {}", nworkers);

            let nthreads = worker.nthreads.unwrap_or(2);
            tracing::debug!("Dask threads per worker: {}", nthreads);

            if let Some(ref memory_limit) = worker.memory_limit {
                tracing::debug!("Dask worker memory limit: {}", memory_limit);
            }

            if let Some(ref worker_port) = worker.worker_port {
                tracing::debug!("Dask worker port range: {}", worker_port);
            }

            if let Some(ref nanny_port) = worker.nanny_port {
                tracing::debug!("Dask nanny port range: {}", nanny_port);
            }

            let death_timeout = worker.death_timeout.unwrap_or(60.0);
            tracing::debug!("Dask worker death timeout: {} seconds", death_timeout);

            if let Some(ref preload) = worker.preload {
                tracing::debug!("Dask worker preload modules: {:?}", preload);
            }

            if let Some(ref env) = worker.env {
                tracing::debug!("Dask worker environment variables: {:?}", env);
            }

            if let Some(ref resources) = worker.resources {
                tracing::debug!("Dask worker resources: {:?}", resources);
            }
        }

        // Runs even without a worker section: stats default to 4 connected
        // workers, matching the logging defaults above.
        self.stats.workers_connected = self
            .config
            .worker
            .as_ref()
            .and_then(|w| w.nworkers)
            .unwrap_or(4);

        Ok(())
    }
735
736 fn setup_client(&self) -> TorshResult<()> {
738 if let Some(ref distributed) = self.config.distributed {
739 if let Some(ref client) = distributed.client {
740 tracing::info!("Setting up Dask client");
741
742 let heartbeat_interval = client.heartbeat_interval.unwrap_or(5.0);
743 tracing::debug!(
744 "Dask client heartbeat interval: {} seconds",
745 heartbeat_interval
746 );
747
748 let scheduler_info_interval = client.scheduler_info_interval.unwrap_or(2.0);
749 tracing::debug!(
750 "Dask client scheduler info interval: {} seconds",
751 scheduler_info_interval
752 );
753
754 if let Some(ref task_metadata) = client.task_metadata {
755 tracing::debug!("Dask client task metadata: {:?}", task_metadata);
756 }
757
758 let set_as_default = client.set_as_default.unwrap_or(true);
759 tracing::debug!("Dask client set as default: {}", set_as_default);
760 }
761 }
762 Ok(())
763 }
764
    /// Log the effective Dask-ML settings for every configured sub-section
    /// (model selection, preprocessing, linear models, ensembles,
    /// clustering). No-op when no ML section is configured.
    fn setup_ml(&self) -> TorshResult<()> {
        if let Some(ref ml) = self.config.ml {
            tracing::info!("Setting up Dask ML");

            if let Some(ref model_selection) = ml.model_selection {
                let cv_folds = model_selection.cv_folds.unwrap_or(5);
                tracing::debug!("Dask ML cross-validation folds: {}", cv_folds);

                if let Some(ref scoring) = model_selection.scoring {
                    tracing::debug!("Dask ML scoring metric: {}", scoring);
                }

                if let Some(search_method) = model_selection.search_method {
                    tracing::debug!("Dask ML search method: {:?}", search_method);
                }
            }

            if let Some(ref preprocessing) = ml.preprocessing {
                if let Some(ref standardization) = preprocessing.standardization {
                    tracing::debug!("Dask ML standardization: {}", standardization);
                }

                if let Some(ref encoding) = preprocessing.categorical_encoding {
                    tracing::debug!("Dask ML categorical encoding: {}", encoding);
                }
            }

            if let Some(ref linear_model) = ml.linear_model {
                if let Some(ref solver) = linear_model.solver {
                    tracing::debug!("Dask ML linear model solver: {}", solver);
                }

                let max_iter = linear_model.max_iter.unwrap_or(1000);
                tracing::debug!("Dask ML linear model max iterations: {}", max_iter);
            }

            if let Some(ref ensemble) = ml.ensemble {
                let n_estimators = ensemble.n_estimators.unwrap_or(100);
                tracing::debug!("Dask ML ensemble estimators: {}", n_estimators);

                let bootstrap = ensemble.bootstrap.unwrap_or(true);
                tracing::debug!("Dask ML ensemble bootstrap: {}", bootstrap);
            }

            if let Some(ref cluster) = ml.cluster {
                let n_clusters = cluster.n_clusters.unwrap_or(8);
                tracing::debug!("Dask ML clustering clusters: {}", n_clusters);

                let max_iter = cluster.max_iter.unwrap_or(300);
                tracing::debug!("Dask ML clustering max iterations: {}", max_iter);
            }
        }
        Ok(())
    }
820
    /// Log the effective distributed-runtime settings (communication,
    /// serialization, scheduling policy, diagnostics), substituting
    /// defaults for unset values. No-op when no distributed section is
    /// configured. The client sub-section is handled by `setup_client`.
    fn setup_distributed(&self) -> TorshResult<()> {
        if let Some(ref distributed) = self.config.distributed {
            tracing::info!("Setting up Dask distributed");

            if let Some(ref comm) = distributed.comm {
                if let Some(ref compression) = comm.compression {
                    tracing::debug!("Dask communication compression: {}", compression);
                }

                if let Some(ref serializers) = comm.serializers {
                    tracing::debug!("Dask communication serializers: {:?}", serializers);
                }

                if let Some(ref timeouts) = comm.timeouts {
                    let connect_timeout = timeouts.connect.unwrap_or(10.0);
                    tracing::debug!(
                        "Dask communication connect timeout: {} seconds",
                        connect_timeout
                    );

                    let tcp_timeout = timeouts.tcp.unwrap_or(30.0);
                    tracing::debug!("Dask communication TCP timeout: {} seconds", tcp_timeout);
                }

                if let Some(ref tcp) = comm.tcp {
                    let reuse_port = tcp.reuse_port.unwrap_or(false);
                    tracing::debug!("Dask TCP reuse port: {}", reuse_port);

                    let no_delay = tcp.no_delay.unwrap_or(true);
                    tracing::debug!("Dask TCP no delay: {}", no_delay);

                    let keep_alive = tcp.keep_alive.unwrap_or(false);
                    tracing::debug!("Dask TCP keep alive: {}", keep_alive);
                }
            }

            if let Some(ref serialization) = distributed.serialization {
                if let Some(ref compression) = serialization.compression {
                    tracing::debug!("Dask serialization compression: {:?}", compression);
                }

                let pickle_protocol = serialization.pickle_protocol.unwrap_or(4);
                tracing::debug!("Dask serialization pickle protocol: {}", pickle_protocol);
            }

            if let Some(ref scheduling) = distributed.scheduling {
                let work_stealing = scheduling.work_stealing.unwrap_or(true);
                tracing::debug!("Dask scheduling work stealing: {}", work_stealing);

                // The interval is only meaningful when stealing is enabled.
                if work_stealing {
                    let interval = scheduling.work_stealing_interval.unwrap_or(0.1);
                    tracing::debug!("Dask work stealing interval: {} seconds", interval);
                }

                let unknown_task_duration = scheduling.unknown_task_duration.unwrap_or(0.5);
                tracing::debug!(
                    "Dask unknown task duration: {} seconds",
                    unknown_task_duration
                );
            }

            if let Some(ref diagnostics) = distributed.diagnostics {
                let progress_bar = diagnostics.progress_bar.unwrap_or(true);
                tracing::debug!("Dask diagnostics progress bar: {}", progress_bar);

                let profile = diagnostics.profile.unwrap_or(false);
                tracing::debug!("Dask diagnostics profile: {}", profile);

                let memory_profiling = diagnostics.memory_profiling.unwrap_or(false);
                tracing::debug!("Dask diagnostics memory profiling: {}", memory_profiling);
            }
        }
        Ok(())
    }
896
    /// Borrow the active configuration.
    pub fn config(&self) -> &DaskConfig {
        &self.config
    }

    /// Borrow the accumulated runtime statistics.
    pub fn stats(&self) -> &DaskStats {
        &self.stats
    }

    /// Whether `initialize` has completed successfully.
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Global rank of this process (set by `initialize`; 0 before).
    pub fn rank(&self) -> u32 {
        self.rank
    }

    /// Total number of processes in the job (1 before `initialize`).
    pub fn world_size(&self) -> u32 {
        self.world_size
    }

    /// Rank of this process on its local node.
    pub fn local_rank(&self) -> u32 {
        self.local_rank
    }

    /// Number of processes on the local node.
    pub fn local_size(&self) -> u32 {
        self.local_size
    }

    /// Whether the (simulated) client connection is currently active.
    pub fn is_client_active(&self) -> bool {
        self.client_active
    }
936
937 pub fn submit_task(&mut self, task_name: &str, task_size: usize) -> TorshResult<String> {
939 if !self.initialized {
940 return Err(TorshDistributedError::BackendNotInitialized);
941 }
942
943 let start_time = std::time::Instant::now();
944
945 tracing::debug!("Submitting Dask task: {} ({} bytes)", task_name, task_size);
946
947 let task_id = format!("task_{}_{}", task_name, self.stats.tasks_executed);
949
950 self.stats.tasks_executed += 1;
952 let execution_time = start_time.elapsed().as_secs_f64();
953 self.stats.task_execution_time_sec += execution_time;
954 self.stats.average_task_duration_sec =
955 self.stats.task_execution_time_sec / self.stats.tasks_executed as f64;
956 self.stats.data_transferred_bytes += task_size as u64;
957
958 tracing::debug!("Dask task submitted: {} (ID: {})", task_name, task_id);
959 Ok(task_id)
960 }
961
962 pub fn compute(&mut self, collection_name: &str) -> TorshResult<()> {
964 if !self.initialized {
965 return Err(TorshDistributedError::BackendNotInitialized);
966 }
967
968 let start_time = std::time::Instant::now();
969
970 tracing::info!("Computing Dask collection: {}", collection_name);
971
972 let num_tasks = 10; for i in 0..num_tasks {
975 self.submit_task(&format!("{}_task_{}", collection_name, i), 1024)?;
976 }
977
978 let execution_time = start_time.elapsed().as_secs_f64();
979 tracing::info!(
980 "Dask collection computed: {} in {:.2}s",
981 collection_name,
982 execution_time
983 );
984 Ok(())
985 }
986
987 pub fn scale_cluster(&mut self, target_workers: u32) -> TorshResult<()> {
989 if !self.initialized {
990 return Err(TorshDistributedError::BackendNotInitialized);
991 }
992
993 tracing::info!("Scaling Dask cluster to {} workers", target_workers);
994
995 if let Some(ref cluster) = self.config.cluster {
996 if let Some(ref scaling) = cluster.scaling {
997 let minimum = scaling.minimum.unwrap_or(1);
998 let maximum = scaling.maximum.unwrap_or(100);
999
1000 if target_workers < minimum {
1001 return Err(TorshDistributedError::invalid_argument(
1002 "target_workers",
1003 format!("Cannot scale below minimum: {}", minimum),
1004 format!("At least {} workers", minimum),
1005 ));
1006 }
1007
1008 if target_workers > maximum {
1009 return Err(TorshDistributedError::invalid_argument(
1010 "target_workers",
1011 format!("Cannot scale above maximum: {}", maximum),
1012 format!("At most {} workers", maximum),
1013 ));
1014 }
1015 }
1016 }
1017
1018 self.stats.workers_connected = target_workers;
1019 tracing::info!("Dask cluster scaled to {} workers", target_workers);
1020 Ok(())
1021 }
1022
1023 pub fn handle_worker_failure(&mut self, worker_id: u32) -> TorshResult<()> {
1025 tracing::warn!("Dask worker {} failed", worker_id);
1026 self.stats.worker_failures += 1;
1027
1028 if self.stats.workers_connected > 0 {
1030 self.stats.workers_connected -= 1;
1031 }
1032
1033 if let Some(ref cluster) = self.config.cluster {
1035 if let Some(ref scaling) = cluster.scaling {
1036 let minimum = scaling.minimum.unwrap_or(1);
1037 if self.stats.workers_connected < minimum {
1038 tracing::info!("Auto-scaling Dask cluster due to worker failure");
1039 self.scale_cluster(minimum)?;
1040 }
1041 }
1042 }
1043
1044 Ok(())
1045 }
1046
1047 pub fn shutdown(&mut self) -> TorshResult<()> {
1049 if self.client_active {
1050 tracing::info!("Shutting down Dask integration");
1051 self.client_active = false;
1052 self.initialized = false;
1053 self.stats.workers_connected = 0;
1054 }
1055 Ok(())
1056 }
1057
    /// Build the stock configuration: a local 4-worker/2-thread cluster
    /// with standard Dask ports (8786/8787), 128 MB partitions, lz4 wire
    /// compression and no ML section.
    pub fn default_config() -> DaskConfig {
        DaskConfig {
            cluster: Some(DaskClusterConfig {
                cluster_type: DaskClusterType::Local,
                n_workers: Some(4),
                threads_per_worker: Some(2),
                memory_limit: Some("4GB".to_string()),
                processes: Some(false),
                dashboard_address: Some("127.0.0.1:8787".to_string()),
                silence_logs: Some(false),
                security: None,
                scaling: Some(DaskScalingConfig {
                    minimum: Some(1),
                    maximum: Some(10),
                    target_cpu: Some(0.8),
                    target_memory: Some(0.8),
                    scale_up_threshold: Some(0.8),
                    scale_down_threshold: Some(0.2),
                    interval: Some(30.0),
                }),
            }),
            scheduler: Some(DaskSchedulerConfig {
                address: None,
                port: Some(8786),
                dashboard_port: Some(8787),
                bokeh_port: Some(8788),
                worker_timeout: Some(60.0),
                idle_timeout: Some(1800.0),
                transition_log_length: Some(100000),
                task_duration_overhead: Some(0.1),
                allowed_failures: Some(3),
            }),
            worker: Some(DaskWorkerConfig {
                nworkers: Some(4),
                nthreads: Some(2),
                memory_limit: Some("4GB".to_string()),
                worker_port: Some("40000:40100".to_string()),
                nanny_port: Some("40100:40200".to_string()),
                dashboard_port: Some(8789),
                death_timeout: Some(60.0),
                preload: None,
                env: None,
                resources: None,
            }),
            array: Some(DaskArrayConfig {
                chunk_size: Some("128MB".to_string()),
                backend: Some("numpy".to_string()),
                overlap: Some(0),
                boundary: Some("reflect".to_string()),
                trim: Some(true),
                rechunk_threshold: Some(4.0),
            }),
            dataframe: Some(DaskDataFrameConfig {
                partition_size: Some("128MB".to_string()),
                shuffle_method: Some(DaskShuffleMethod::Tasks),
                query_planning: Some(false),
                backend: Some("pandas".to_string()),
                optimize_index: Some(true),
            }),
            bag: Some(DaskBagConfig {
                // 134217728 bytes = 128 MB, matching the array/dataframe sizes.
                partition_size: Some(134217728),
                compression: Some("gzip".to_string()),
                encoding: Some("utf-8".to_string()),
                linedelimiter: Some("\n".to_string()),
            }),
            // ML is opt-in; see `config_with_ml`.
            ml: None,
            distributed: Some(DaskDistributedConfig {
                comm: Some(DaskCommConfig {
                    compression: Some("lz4".to_string()),
                    serializers: Some(vec!["pickle".to_string(), "msgpack".to_string()]),
                    timeouts: Some(DaskTimeoutsConfig {
                        connect: Some(10.0),
                        tcp: Some(30.0),
                    }),
                    tcp: Some(DaskTcpConfig {
                        reuse_port: Some(false),
                        no_delay: Some(true),
                        keep_alive: Some(false),
                    }),
                }),
                serialization: Some(DaskSerializationConfig {
                    compression: Some(vec!["lz4".to_string(), "zlib".to_string()]),
                    default_serializers: Some(vec!["pickle".to_string()]),
                    pickle_protocol: Some(4),
                }),
                client: Some(DaskClientConfig {
                    heartbeat_interval: Some(5.0),
                    scheduler_info_interval: Some(2.0),
                    task_metadata: Some(vec!["task_name".to_string(), "worker".to_string()]),
                    set_as_default: Some(true),
                }),
                scheduling: Some(DaskSchedulingConfig {
                    work_stealing: Some(true),
                    work_stealing_interval: Some(0.1),
                    unknown_task_duration: Some(0.5),
                    default_task_durations: None,
                }),
                diagnostics: Some(DaskDiagnosticsConfig {
                    progress_bar: Some(true),
                    profile: Some(false),
                    memory_profiling: Some(false),
                    task_stream: Some(false),
                    resource_monitor: Some(false),
                }),
            }),
        }
    }
1166
    /// Build the default configuration with a fully populated Dask-ML
    /// section (random search with 5-fold CV, standard scikit-learn-style
    /// preprocessing names, lbfgs linear models, 100-tree ensembles and
    /// 8-cluster k-means).
    pub fn config_with_ml() -> DaskConfig {
        let mut config = Self::default_config();

        config.ml = Some(DaskMLConfig {
            model_selection: Some(DaskMLModelSelectionConfig {
                cv_folds: Some(5),
                scoring: Some("accuracy".to_string()),
                // -1 parallel jobs, as passed through to the search backend.
                n_jobs: Some(-1),
                search_method: Some(DaskMLSearchMethod::RandomSearch),
            }),
            preprocessing: Some(DaskMLPreprocessingConfig {
                standardization: Some("StandardScaler".to_string()),
                categorical_encoding: Some("OneHotEncoder".to_string()),
                feature_selection: Some("SelectKBest".to_string()),
                dimensionality_reduction: Some("PCA".to_string()),
            }),
            linear_model: Some(DaskMLLinearModelConfig {
                solver: Some("lbfgs".to_string()),
                alpha: Some(1.0),
                max_iter: Some(1000),
                tol: Some(1e-4),
            }),
            ensemble: Some(DaskMLEnsembleConfig {
                n_estimators: Some(100),
                bootstrap: Some(true),
                random_state: Some(42),
                oob_score: Some(true),
            }),
            cluster: Some(DaskMLClusterConfig {
                n_clusters: Some(8),
                init: Some("k-means++".to_string()),
                max_iter: Some(300),
                tol: Some(1e-4),
                n_init: Some(10),
            }),
        });

        config
    }
1207
1208 pub fn config_with_large_scale(n_workers: u32, memory_per_worker: &str) -> DaskConfig {
1210 let mut config = Self::default_config();
1211
1212 if let Some(ref mut cluster) = config.cluster {
1213 cluster.n_workers = Some(n_workers);
1214 cluster.memory_limit = Some(memory_per_worker.to_string());
1215 cluster.processes = Some(true); if let Some(ref mut scaling) = cluster.scaling {
1218 scaling.minimum = Some(n_workers / 2);
1219 scaling.maximum = Some(n_workers * 2);
1220 }
1221 }
1222
1223 if let Some(ref mut worker) = config.worker {
1224 worker.nworkers = Some(n_workers);
1225 worker.memory_limit = Some(memory_per_worker.to_string());
1226 }
1227
1228 if let Some(ref mut distributed) = config.distributed {
1230 if let Some(ref mut comm) = distributed.comm {
1231 comm.compression = Some("zstd".to_string()); }
1233
1234 if let Some(ref mut scheduling) = distributed.scheduling {
1235 scheduling.work_stealing_interval = Some(0.5); }
1237 }
1238
1239 config
1240 }
1241}
1242
/// The default configuration mirrors [`DaskIntegration::default_config`]:
/// a local 4-worker cluster with standard ports and no ML section.
impl Default for DaskConfig {
    fn default() -> Self {
        DaskIntegration::default_config()
    }
}
1248
1249#[cfg(test)]
1250mod tests {
1251 use super::*;
1252
    #[test]
    fn test_dask_config_validation() {
        // The default config must initialize cleanly and expose the
        // topology passed to `initialize` (rank 0 of 4, local rank 0 of 2).
        let config = DaskIntegration::default_config();
        let mut integration = DaskIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());
        assert!(integration.is_initialized());
        assert!(integration.is_client_active());
        assert_eq!(integration.rank(), 0);
        assert_eq!(integration.world_size(), 4);
        assert_eq!(integration.local_rank(), 0);
        // workers_connected is seeded from the default worker config (nworkers = 4).
        assert_eq!(integration.stats().workers_connected, 4);
    }
1267
    #[test]
    fn test_dask_task_submission() {
        let config = DaskIntegration::default_config();
        let mut integration = DaskIntegration::new(config);

        assert!(integration.initialize(0, 4, 0, 2).is_ok());

        // Each submission returns an id embedding the task name.
        let task_id1 = integration.submit_task("compute_gradient", 1024).unwrap();
        let task_id2 = integration.submit_task("update_parameters", 2048).unwrap();

        assert!(task_id1.contains("compute_gradient"));
        assert!(task_id2.contains("update_parameters"));

        // Counters reflect both submissions: 1024 + 2048 = 3072 bytes.
        let stats = integration.stats();
        assert_eq!(stats.tasks_executed, 2);
        assert!(stats.task_execution_time_sec > 0.0);
        assert_eq!(stats.data_transferred_bytes, 3072);
        assert!(stats.average_task_duration_sec > 0.0);
    }
1288
1289 #[test]
1290 fn test_dask_compute_collection() {
1291 let config = DaskIntegration::default_config();
1292 let mut integration = DaskIntegration::new(config);
1293
1294 assert!(integration.initialize(0, 4, 0, 2).is_ok());
1295
1296 assert!(integration.compute("training_dataset").is_ok());
1298
1299 let stats = integration.stats();
1300 assert_eq!(stats.tasks_executed, 10); assert!(stats.task_execution_time_sec > 0.0);
1302 }
1303
1304 #[test]
1305 fn test_dask_cluster_scaling() {
1306 let config = DaskIntegration::default_config();
1307 let mut integration = DaskIntegration::new(config);
1308
1309 assert!(integration.initialize(0, 4, 0, 2).is_ok());
1310
1311 assert!(integration.scale_cluster(8).is_ok());
1313 assert_eq!(integration.stats().workers_connected, 8);
1314
1315 assert!(integration.scale_cluster(2).is_ok());
1317 assert_eq!(integration.stats().workers_connected, 2);
1318
1319 assert!(integration.scale_cluster(0).is_err());
1321
1322 assert!(integration.scale_cluster(20).is_err());
1324 }
1325
1326 #[test]
1327 fn test_dask_worker_failure_handling() {
1328 let config = DaskIntegration::default_config();
1329 let mut integration = DaskIntegration::new(config);
1330
1331 assert!(integration.initialize(0, 4, 0, 2).is_ok());
1332
1333 assert!(integration.handle_worker_failure(1).is_ok());
1335 assert_eq!(integration.stats().worker_failures, 1);
1336 assert_eq!(integration.stats().workers_connected, 3);
1337
1338 assert!(integration.handle_worker_failure(2).is_ok());
1340 assert!(integration.handle_worker_failure(3).is_ok());
1341 assert_eq!(integration.stats().worker_failures, 3);
1342 assert_eq!(integration.stats().workers_connected, 1); }
1344
1345 #[test]
1346 fn test_dask_ml_config() {
1347 let config = DaskIntegration::config_with_ml();
1348 let mut integration = DaskIntegration::new(config);
1349
1350 assert!(integration.initialize(0, 4, 0, 2).is_ok());
1351
1352 assert!(integration.config().ml.is_some());
1354
1355 if let Some(ref ml) = integration.config().ml {
1356 assert!(ml.model_selection.is_some());
1357 assert!(ml.preprocessing.is_some());
1358 assert!(ml.linear_model.is_some());
1359 assert!(ml.ensemble.is_some());
1360 assert!(ml.cluster.is_some());
1361
1362 if let Some(ref model_selection) = ml.model_selection {
1363 assert_eq!(model_selection.cv_folds, Some(5));
1364 assert_eq!(
1365 model_selection.search_method,
1366 Some(DaskMLSearchMethod::RandomSearch)
1367 );
1368 }
1369 }
1370 }
1371
1372 #[test]
1373 fn test_dask_large_scale_config() {
1374 let config = DaskIntegration::config_with_large_scale(16, "8GB");
1375 let mut integration = DaskIntegration::new(config);
1376
1377 assert!(integration.initialize(0, 16, 0, 4).is_ok());
1378
1379 if let Some(ref cluster) = integration.config().cluster {
1381 assert_eq!(cluster.n_workers, Some(16));
1382 assert_eq!(cluster.memory_limit, Some("8GB".to_string()));
1383 assert_eq!(cluster.processes, Some(true));
1384
1385 if let Some(ref scaling) = cluster.scaling {
1386 assert_eq!(scaling.minimum, Some(8));
1387 assert_eq!(scaling.maximum, Some(32));
1388 }
1389 }
1390 }
1391
1392 #[test]
1393 fn test_dask_shutdown() {
1394 let config = DaskIntegration::default_config();
1395 let mut integration = DaskIntegration::new(config);
1396
1397 assert!(integration.initialize(0, 4, 0, 2).is_ok());
1398 assert!(integration.is_client_active());
1399
1400 assert!(integration.shutdown().is_ok());
1401 assert!(!integration.is_client_active());
1402 assert!(!integration.is_initialized());
1403 assert_eq!(integration.stats().workers_connected, 0);
1404 }
1405
1406 #[test]
1407 fn test_dask_config_serialization() {
1408 let config = DaskIntegration::config_with_ml();
1409
1410 let json = serde_json::to_string(&config).unwrap();
1412 assert!(json.contains("Local"));
1413 assert!(json.contains("accuracy"));
1414 assert!(json.contains("RandomSearch"));
1415
1416 let deserialized: DaskConfig = serde_json::from_str(&json).unwrap();
1418 assert!(deserialized.cluster.is_some());
1419 assert!(deserialized.ml.is_some());
1420
1421 if let Some(cluster) = deserialized.cluster {
1422 assert_eq!(cluster.cluster_type, DaskClusterType::Local);
1423 }
1424 }
1425
1426 #[test]
1427 fn test_dask_invalid_config() {
1428 let mut config = DaskIntegration::default_config();
1429
1430 if let Some(ref mut cluster) = config.cluster {
1432 cluster.n_workers = Some(0); }
1434
1435 let mut integration = DaskIntegration::new(config);
1436
1437 assert!(integration.initialize(0, 4, 0, 2).is_err());
1439 }
1440
1441 #[test]
1442 fn test_dask_scaling_validation() {
1443 let mut config = DaskIntegration::default_config();
1444
1445 if let Some(ref mut cluster) = config.cluster {
1447 if let Some(ref mut scaling) = cluster.scaling {
1448 scaling.minimum = Some(10);
1449 scaling.maximum = Some(5); }
1451 }
1452
1453 let mut integration = DaskIntegration::new(config);
1454
1455 assert!(integration.initialize(0, 4, 0, 2).is_err());
1457 }
1458}