// torsh_distributed/dask_integration.rs
//! Dask integration for ToRSh distributed training
//!
//! This module provides compatibility with Dask's parallel computing framework,
//! allowing users to leverage Dask's distributed computing capabilities
//! with ToRSh distributed training.
//!
//! Dask is a flexible library for parallel computing in Python that provides:
//! - Dask Array: Parallel NumPy-like arrays
//! - Dask DataFrame: Parallel Pandas-like dataframes
//! - Dask Bag: Parallel collections for unstructured data
//! - Dask Distributed: Distributed computing with task scheduling
//! - Dask ML: Machine learning algorithms
//! - Dask Gateway: Secure cluster management
14
15use crate::{TorshDistributedError, TorshResult};
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::path::Path;
19
/// Dask configuration compatible with ToRSh.
///
/// Every section is optional; a section that is `None` simply skips the
/// corresponding setup step during `DaskIntegration::initialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskConfig {
    /// Dask cluster configuration
    pub cluster: Option<DaskClusterConfig>,
    /// Dask scheduler configuration
    pub scheduler: Option<DaskSchedulerConfig>,
    /// Dask worker configuration
    pub worker: Option<DaskWorkerConfig>,
    /// Dask array configuration
    pub array: Option<DaskArrayConfig>,
    /// Dask dataframe configuration
    pub dataframe: Option<DaskDataFrameConfig>,
    /// Dask bag configuration
    pub bag: Option<DaskBagConfig>,
    /// Dask ML configuration
    pub ml: Option<DaskMLConfig>,
    /// Dask distributed configuration
    pub distributed: Option<DaskDistributedConfig>,
}
40
/// Dask cluster configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskClusterConfig {
    /// Cluster type
    pub cluster_type: DaskClusterType,
    /// Number of workers (setup falls back to 4 when unset)
    pub n_workers: Option<u32>,
    /// Threads per worker (setup falls back to 2 when unset)
    pub threads_per_worker: Option<u32>,
    /// Memory per worker (free-form string, e.g. a size spec — format not validated here)
    pub memory_limit: Option<String>,
    /// Use processes instead of threads (defaults to false)
    pub processes: Option<bool>,
    /// Dashboard address
    pub dashboard_address: Option<String>,
    /// Silence logs (defaults to false)
    pub silence_logs: Option<bool>,
    /// Security configuration
    pub security: Option<DaskSecurityConfig>,
    /// Cluster scaling configuration
    pub scaling: Option<DaskScalingConfig>,
}

/// Dask cluster types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DaskClusterType {
    /// Local cluster
    Local,
    /// LocalCluster with processes
    LocalProcess,
    /// Kubernetes cluster
    Kubernetes,
    /// SLURM cluster
    Slurm,
    /// PBS cluster
    PBS,
    /// SGE cluster
    SGE,
    /// Custom cluster
    Custom,
}

/// Dask security (TLS) configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskSecurityConfig {
    /// TLS certificate file
    pub tls_cert: Option<String>,
    /// TLS key file
    pub tls_key: Option<String>,
    /// TLS CA file
    pub tls_ca_file: Option<String>,
    /// Require encryption (defaults to false)
    pub require_encryption: Option<bool>,
}

/// Dask adaptive-scaling configuration.
///
/// Validation rejects configs where `minimum > maximum`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskScalingConfig {
    /// Minimum workers (defaults to 1)
    pub minimum: Option<u32>,
    /// Maximum workers (defaults to twice the cluster worker count)
    pub maximum: Option<u32>,
    /// Target CPU utilization (defaults to 0.8)
    pub target_cpu: Option<f32>,
    /// Target memory utilization
    pub target_memory: Option<f32>,
    /// Scale up threshold
    pub scale_up_threshold: Option<f32>,
    /// Scale down threshold
    pub scale_down_threshold: Option<f32>,
    /// Interval for scaling decisions, in seconds (defaults to 30.0)
    pub interval: Option<f32>,
}
114
/// Dask scheduler configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskSchedulerConfig {
    /// Scheduler address
    pub address: Option<String>,
    /// Scheduler port (defaults to 8786; validation requires > 0)
    pub port: Option<u16>,
    /// Dashboard port (defaults to 8787)
    pub dashboard_port: Option<u16>,
    /// Bokeh port
    pub bokeh_port: Option<u16>,
    /// Worker timeout in seconds (defaults to 60.0; validation requires > 0)
    pub worker_timeout: Option<f32>,
    /// Idle timeout in seconds (defaults to 1800.0)
    pub idle_timeout: Option<f32>,
    /// Transition log length
    pub transition_log_length: Option<u32>,
    /// Task duration overhead
    pub task_duration_overhead: Option<f32>,
    /// Allowed failures (defaults to 3)
    pub allowed_failures: Option<u32>,
}

/// Dask worker configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskWorkerConfig {
    /// Number of workers (defaults to 4; validation requires > 0)
    pub nworkers: Option<u32>,
    /// Threads per worker (defaults to 2; validation requires > 0)
    pub nthreads: Option<u32>,
    /// Memory limit per worker
    pub memory_limit: Option<String>,
    /// Worker port range (string-form range)
    pub worker_port: Option<String>,
    /// Nanny port range (string-form range)
    pub nanny_port: Option<String>,
    /// Dashboard port
    pub dashboard_port: Option<u16>,
    /// Death timeout in seconds (defaults to 60.0)
    pub death_timeout: Option<f32>,
    /// Preload modules
    pub preload: Option<Vec<String>>,
    /// Environment variables
    pub env: Option<HashMap<String, String>>,
    /// Resources (name -> amount)
    pub resources: Option<HashMap<String, f32>>,
}
162
/// Dask array configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskArrayConfig {
    /// Default chunk size
    pub chunk_size: Option<String>,
    /// Array backend
    pub backend: Option<String>,
    /// Overlap for sliding window operations
    pub overlap: Option<u32>,
    /// Boundary conditions
    pub boundary: Option<String>,
    /// Trim excess data
    pub trim: Option<bool>,
    /// Rechunk threshold
    pub rechunk_threshold: Option<f32>,
}

/// Dask dataframe configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskDataFrameConfig {
    /// Default partition size
    pub partition_size: Option<String>,
    /// Shuffle method
    pub shuffle_method: Option<DaskShuffleMethod>,
    /// Query planning
    pub query_planning: Option<bool>,
    /// Dataframe backend
    pub backend: Option<String>,
    /// Index optimization
    pub optimize_index: Option<bool>,
}

/// Dask shuffle methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DaskShuffleMethod {
    /// Disk-based shuffle
    Disk,
    /// Tasks-based shuffle
    Tasks,
    /// P2P shuffle
    P2P,
}

/// Dask bag configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskBagConfig {
    /// Default partition size
    pub partition_size: Option<u64>,
    /// Compression for storage
    pub compression: Option<String>,
    /// Text encoding
    pub encoding: Option<String>,
    /// Line delimiter used to split text input
    pub linedelimiter: Option<String>,
}
218
/// Dask ML configuration (all sub-sections optional).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLConfig {
    /// Model selection configuration
    pub model_selection: Option<DaskMLModelSelectionConfig>,
    /// Preprocessing configuration
    pub preprocessing: Option<DaskMLPreprocessingConfig>,
    /// Linear models configuration
    pub linear_model: Option<DaskMLLinearModelConfig>,
    /// Ensemble configuration
    pub ensemble: Option<DaskMLEnsembleConfig>,
    /// Clustering configuration
    pub cluster: Option<DaskMLClusterConfig>,
}

/// Dask ML model selection configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLModelSelectionConfig {
    /// Cross-validation folds (defaults to 5)
    pub cv_folds: Option<u32>,
    /// Scoring metric
    pub scoring: Option<String>,
    /// N jobs for parallel execution (negative conventionally means "all cores" — TODO confirm against consumer)
    pub n_jobs: Option<i32>,
    /// Hyperparameter search method
    pub search_method: Option<DaskMLSearchMethod>,
}

/// Dask ML hyperparameter search methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DaskMLSearchMethod {
    /// Grid search
    GridSearch,
    /// Random search
    RandomSearch,
    /// Successive halving
    SuccessiveHalving,
    /// Hyperband
    Hyperband,
}

/// Dask ML preprocessing configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLPreprocessingConfig {
    /// Standardization method
    pub standardization: Option<String>,
    /// Categorical encoding
    pub categorical_encoding: Option<String>,
    /// Feature selection
    pub feature_selection: Option<String>,
    /// Dimensionality reduction
    pub dimensionality_reduction: Option<String>,
}

/// Dask ML linear model configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLLinearModelConfig {
    /// Solver for linear models
    pub solver: Option<String>,
    /// Regularization parameter
    pub alpha: Option<f32>,
    /// Maximum iterations (defaults to 1000)
    pub max_iter: Option<u32>,
    /// Tolerance for convergence
    pub tol: Option<f32>,
}

/// Dask ML ensemble configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLEnsembleConfig {
    /// Number of estimators (defaults to 100)
    pub n_estimators: Option<u32>,
    /// Bootstrap sampling (defaults to true)
    pub bootstrap: Option<bool>,
    /// Random state
    pub random_state: Option<u32>,
    /// Out of bag score
    pub oob_score: Option<bool>,
}

/// Dask ML clustering configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskMLClusterConfig {
    /// Number of clusters (defaults to 8)
    pub n_clusters: Option<u32>,
    /// Initialization method
    pub init: Option<String>,
    /// Maximum iterations (defaults to 300)
    pub max_iter: Option<u32>,
    /// Tolerance
    pub tol: Option<f32>,
    /// Number of init runs
    pub n_init: Option<u32>,
}
313
/// Dask distributed-runtime configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskDistributedConfig {
    /// Communication configuration
    pub comm: Option<DaskCommConfig>,
    /// Serialization configuration
    pub serialization: Option<DaskSerializationConfig>,
    /// Client configuration
    pub client: Option<DaskClientConfig>,
    /// Task scheduling configuration
    pub scheduling: Option<DaskSchedulingConfig>,
    /// Diagnostics configuration
    pub diagnostics: Option<DaskDiagnosticsConfig>,
}

/// Dask communication configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskCommConfig {
    /// Compression algorithm
    pub compression: Option<String>,
    /// Default serializers
    pub serializers: Option<Vec<String>>,
    /// Timeouts
    pub timeouts: Option<DaskTimeoutsConfig>,
    /// TCP configuration
    pub tcp: Option<DaskTcpConfig>,
}

/// Dask communication timeouts, in seconds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskTimeoutsConfig {
    /// Connect timeout in seconds (defaults to 10.0)
    pub connect: Option<f32>,
    /// TCP timeout in seconds (defaults to 30.0)
    pub tcp: Option<f32>,
}

/// Dask TCP socket options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskTcpConfig {
    /// Reuse port (defaults to false)
    pub reuse_port: Option<bool>,
    /// No delay (defaults to true)
    pub no_delay: Option<bool>,
    /// Keep alive (defaults to false)
    pub keep_alive: Option<bool>,
}

/// Dask serialization configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskSerializationConfig {
    /// Compression algorithms
    pub compression: Option<Vec<String>>,
    /// Default serializers
    pub default_serializers: Option<Vec<String>>,
    /// Pickle protocol version (defaults to 4)
    pub pickle_protocol: Option<u32>,
}

/// Dask client configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskClientConfig {
    /// Heartbeat interval in seconds (defaults to 5.0)
    pub heartbeat_interval: Option<f32>,
    /// Scheduler info interval in seconds (defaults to 2.0)
    pub scheduler_info_interval: Option<f32>,
    /// Task metadata
    pub task_metadata: Option<Vec<String>>,
    /// Set as default client (defaults to true)
    pub set_as_default: Option<bool>,
}

/// Dask task-scheduling configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskSchedulingConfig {
    /// Work stealing (defaults to true)
    pub work_stealing: Option<bool>,
    /// Work stealing interval in seconds (defaults to 0.1)
    pub work_stealing_interval: Option<f32>,
    /// Duration assumed for unknown tasks, in seconds (defaults to 0.5)
    pub unknown_task_duration: Option<f32>,
    /// Default task durations keyed by task name
    pub default_task_durations: Option<HashMap<String, f32>>,
}

/// Dask diagnostics toggles.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DaskDiagnosticsConfig {
    /// Progress bar (defaults to true)
    pub progress_bar: Option<bool>,
    /// Profile (defaults to false)
    pub profile: Option<bool>,
    /// Memory profiling (defaults to false)
    pub memory_profiling: Option<bool>,
    /// Task stream
    pub task_stream: Option<bool>,
    /// Resource monitor
    pub resource_monitor: Option<bool>,
}
413
/// Dask integration statistics.
///
/// All counters start at zero via `Default`; only `workers_connected`
/// is populated by the setup code in this module — the remaining fields
/// are presumably updated by runtime paths not shown here.
#[derive(Debug, Clone, Default)]
pub struct DaskStats {
    /// Number of tasks executed
    pub tasks_executed: u64,
    /// Total task execution time (seconds)
    pub task_execution_time_sec: f64,
    /// Number of workers connected
    pub workers_connected: u32,
    /// Total data transferred (bytes)
    pub data_transferred_bytes: u64,
    /// Number of task retries
    pub task_retries: u64,
    /// Number of worker failures
    pub worker_failures: u64,
    /// Memory usage (bytes)
    pub memory_usage_bytes: u64,
    /// CPU utilization
    pub cpu_utilization: f64,
    /// Network bandwidth (bytes/sec)
    pub network_bandwidth_bytes_per_sec: f64,
    /// Average task duration (seconds)
    pub average_task_duration_sec: f64,
}
438
/// Dask compatibility integration.
///
/// Holds the parsed configuration plus the process-topology fields that
/// `initialize` records; `initialized`/`client_active` gate re-use.
pub struct DaskIntegration {
    /// Configuration
    config: DaskConfig,
    /// Statistics
    stats: DaskStats,
    /// Initialization status (set once by `initialize`)
    initialized: bool,
    /// Process rank
    rank: u32,
    /// World size
    world_size: u32,
    /// Local rank
    local_rank: u32,
    /// Local size
    local_size: u32,
    /// Dask client active
    client_active: bool,
}
458
459impl DaskIntegration {
460    /// Create a new Dask integration
461    pub fn new(config: DaskConfig) -> Self {
462        Self {
463            config,
464            stats: DaskStats::default(),
465            initialized: false,
466            rank: 0,
467            world_size: 1,
468            local_rank: 0,
469            local_size: 1,
470            client_active: false,
471        }
472    }
473
474    /// Load configuration from JSON file
475    pub fn from_file<P: AsRef<Path>>(path: P) -> TorshResult<Self> {
476        let content = std::fs::read_to_string(path).map_err(|e| {
477            TorshDistributedError::configuration_error(format!(
478                "Failed to read Dask config file: {}",
479                e
480            ))
481        })?;
482
483        let config: DaskConfig = serde_json::from_str(&content).map_err(|e| {
484            TorshDistributedError::configuration_error(format!(
485                "Failed to parse Dask config: {}",
486                e
487            ))
488        })?;
489
490        Ok(Self::new(config))
491    }
492
493    /// Initialize Dask integration
494    pub fn initialize(
495        &mut self,
496        rank: u32,
497        world_size: u32,
498        local_rank: u32,
499        local_size: u32,
500    ) -> TorshResult<()> {
501        if self.initialized {
502            return Err(TorshDistributedError::configuration_error(
503                "Dask integration already initialized",
504            ));
505        }
506
507        self.rank = rank;
508        self.world_size = world_size;
509        self.local_rank = local_rank;
510        self.local_size = local_size;
511
512        self.validate_config()?;
513        self.setup_cluster()?;
514        self.setup_scheduler()?;
515        self.setup_workers()?;
516        self.setup_client()?;
517        self.setup_ml()?;
518        self.setup_distributed()?;
519
520        self.initialized = true;
521        self.client_active = true;
522
523        tracing::info!(
524            "Dask integration initialized - rank: {}, world_size: {}, local_rank: {}",
525            self.rank,
526            self.world_size,
527            self.local_rank
528        );
529
530        Ok(())
531    }
532
533    /// Validate Dask configuration
534    fn validate_config(&self) -> TorshResult<()> {
535        // Validate cluster configuration
536        if let Some(ref cluster) = self.config.cluster {
537            if let Some(n_workers) = cluster.n_workers {
538                if n_workers == 0 {
539                    return Err(TorshDistributedError::configuration_error(
540                        "Dask cluster n_workers must be greater than 0",
541                    ));
542                }
543            }
544
545            if let Some(threads_per_worker) = cluster.threads_per_worker {
546                if threads_per_worker == 0 {
547                    return Err(TorshDistributedError::configuration_error(
548                        "Dask cluster threads_per_worker must be greater than 0",
549                    ));
550                }
551            }
552
553            if let Some(ref scaling) = cluster.scaling {
554                if let Some(minimum) = scaling.minimum {
555                    if let Some(maximum) = scaling.maximum {
556                        if minimum > maximum {
557                            return Err(TorshDistributedError::configuration_error(
558                                "Dask scaling minimum workers cannot exceed maximum workers",
559                            ));
560                        }
561                    }
562                }
563            }
564        }
565
566        // Validate scheduler configuration
567        if let Some(ref scheduler) = self.config.scheduler {
568            if let Some(port) = scheduler.port {
569                if port == 0 {
570                    return Err(TorshDistributedError::configuration_error(
571                        "Dask scheduler port must be greater than 0",
572                    ));
573                }
574            }
575
576            if let Some(worker_timeout) = scheduler.worker_timeout {
577                if worker_timeout <= 0.0 {
578                    return Err(TorshDistributedError::configuration_error(
579                        "Dask scheduler worker_timeout must be greater than 0",
580                    ));
581                }
582            }
583        }
584
585        // Validate worker configuration
586        if let Some(ref worker) = self.config.worker {
587            if let Some(nworkers) = worker.nworkers {
588                if nworkers == 0 {
589                    return Err(TorshDistributedError::configuration_error(
590                        "Dask worker nworkers must be greater than 0",
591                    ));
592                }
593            }
594
595            if let Some(nthreads) = worker.nthreads {
596                if nthreads == 0 {
597                    return Err(TorshDistributedError::configuration_error(
598                        "Dask worker nthreads must be greater than 0",
599                    ));
600                }
601            }
602        }
603
604        Ok(())
605    }
606
607    /// Setup Dask cluster
608    fn setup_cluster(&self) -> TorshResult<()> {
609        if let Some(ref cluster) = self.config.cluster {
610            tracing::info!("Setting up Dask cluster: {:?}", cluster.cluster_type);
611
612            let n_workers = cluster.n_workers.unwrap_or(4);
613            tracing::debug!("Dask cluster workers: {}", n_workers);
614
615            let threads_per_worker = cluster.threads_per_worker.unwrap_or(2);
616            tracing::debug!("Dask threads per worker: {}", threads_per_worker);
617
618            if let Some(ref memory_limit) = cluster.memory_limit {
619                tracing::debug!("Dask memory limit per worker: {}", memory_limit);
620            }
621
622            let processes = cluster.processes.unwrap_or(false);
623            tracing::debug!("Dask use processes: {}", processes);
624
625            if let Some(ref dashboard_address) = cluster.dashboard_address {
626                tracing::debug!("Dask dashboard address: {}", dashboard_address);
627            }
628
629            let silence_logs = cluster.silence_logs.unwrap_or(false);
630            tracing::debug!("Dask silence logs: {}", silence_logs);
631
632            if let Some(ref security) = cluster.security {
633                tracing::debug!("Dask security enabled");
634                if let Some(ref tls_cert) = security.tls_cert {
635                    tracing::debug!("Dask TLS cert: {}", tls_cert);
636                }
637                if let Some(ref tls_key) = security.tls_key {
638                    tracing::debug!("Dask TLS key: {}", tls_key);
639                }
640                let require_encryption = security.require_encryption.unwrap_or(false);
641                tracing::debug!("Dask require encryption: {}", require_encryption);
642            }
643
644            if let Some(ref scaling) = cluster.scaling {
645                let minimum = scaling.minimum.unwrap_or(1);
646                let maximum = scaling.maximum.unwrap_or(n_workers * 2);
647                tracing::debug!("Dask scaling: {} - {} workers", minimum, maximum);
648
649                let target_cpu = scaling.target_cpu.unwrap_or(0.8);
650                tracing::debug!("Dask target CPU utilization: {}", target_cpu);
651
652                let interval = scaling.interval.unwrap_or(30.0);
653                tracing::debug!("Dask scaling interval: {} seconds", interval);
654            }
655        }
656        Ok(())
657    }
658
659    /// Setup Dask scheduler
660    fn setup_scheduler(&self) -> TorshResult<()> {
661        if let Some(ref scheduler) = self.config.scheduler {
662            tracing::info!("Setting up Dask scheduler");
663
664            if let Some(ref address) = scheduler.address {
665                tracing::debug!("Dask scheduler address: {}", address);
666            }
667
668            let port = scheduler.port.unwrap_or(8786);
669            tracing::debug!("Dask scheduler port: {}", port);
670
671            let dashboard_port = scheduler.dashboard_port.unwrap_or(8787);
672            tracing::debug!("Dask scheduler dashboard port: {}", dashboard_port);
673
674            let worker_timeout = scheduler.worker_timeout.unwrap_or(60.0);
675            tracing::debug!("Dask scheduler worker timeout: {} seconds", worker_timeout);
676
677            let idle_timeout = scheduler.idle_timeout.unwrap_or(1800.0);
678            tracing::debug!("Dask scheduler idle timeout: {} seconds", idle_timeout);
679
680            let allowed_failures = scheduler.allowed_failures.unwrap_or(3);
681            tracing::debug!("Dask scheduler allowed failures: {}", allowed_failures);
682        }
683        Ok(())
684    }
685
686    /// Setup Dask workers
687    fn setup_workers(&mut self) -> TorshResult<()> {
688        if let Some(ref worker) = self.config.worker {
689            tracing::info!("Setting up Dask workers");
690
691            let nworkers = worker.nworkers.unwrap_or(4);
692            tracing::debug!("Dask number of workers: {}", nworkers);
693
694            let nthreads = worker.nthreads.unwrap_or(2);
695            tracing::debug!("Dask threads per worker: {}", nthreads);
696
697            if let Some(ref memory_limit) = worker.memory_limit {
698                tracing::debug!("Dask worker memory limit: {}", memory_limit);
699            }
700
701            if let Some(ref worker_port) = worker.worker_port {
702                tracing::debug!("Dask worker port range: {}", worker_port);
703            }
704
705            if let Some(ref nanny_port) = worker.nanny_port {
706                tracing::debug!("Dask nanny port range: {}", nanny_port);
707            }
708
709            let death_timeout = worker.death_timeout.unwrap_or(60.0);
710            tracing::debug!("Dask worker death timeout: {} seconds", death_timeout);
711
712            if let Some(ref preload) = worker.preload {
713                tracing::debug!("Dask worker preload modules: {:?}", preload);
714            }
715
716            if let Some(ref env) = worker.env {
717                tracing::debug!("Dask worker environment variables: {:?}", env);
718            }
719
720            if let Some(ref resources) = worker.resources {
721                tracing::debug!("Dask worker resources: {:?}", resources);
722            }
723        }
724
725        // Update stats
726        self.stats.workers_connected = self
727            .config
728            .worker
729            .as_ref()
730            .and_then(|w| w.nworkers)
731            .unwrap_or(4);
732
733        Ok(())
734    }
735
736    /// Setup Dask client
737    fn setup_client(&self) -> TorshResult<()> {
738        if let Some(ref distributed) = self.config.distributed {
739            if let Some(ref client) = distributed.client {
740                tracing::info!("Setting up Dask client");
741
742                let heartbeat_interval = client.heartbeat_interval.unwrap_or(5.0);
743                tracing::debug!(
744                    "Dask client heartbeat interval: {} seconds",
745                    heartbeat_interval
746                );
747
748                let scheduler_info_interval = client.scheduler_info_interval.unwrap_or(2.0);
749                tracing::debug!(
750                    "Dask client scheduler info interval: {} seconds",
751                    scheduler_info_interval
752                );
753
754                if let Some(ref task_metadata) = client.task_metadata {
755                    tracing::debug!("Dask client task metadata: {:?}", task_metadata);
756                }
757
758                let set_as_default = client.set_as_default.unwrap_or(true);
759                tracing::debug!("Dask client set as default: {}", set_as_default);
760            }
761        }
762        Ok(())
763    }
764
765    /// Setup Dask ML
766    fn setup_ml(&self) -> TorshResult<()> {
767        if let Some(ref ml) = self.config.ml {
768            tracing::info!("Setting up Dask ML");
769
770            if let Some(ref model_selection) = ml.model_selection {
771                let cv_folds = model_selection.cv_folds.unwrap_or(5);
772                tracing::debug!("Dask ML cross-validation folds: {}", cv_folds);
773
774                if let Some(ref scoring) = model_selection.scoring {
775                    tracing::debug!("Dask ML scoring metric: {}", scoring);
776                }
777
778                if let Some(search_method) = model_selection.search_method {
779                    tracing::debug!("Dask ML search method: {:?}", search_method);
780                }
781            }
782
783            if let Some(ref preprocessing) = ml.preprocessing {
784                if let Some(ref standardization) = preprocessing.standardization {
785                    tracing::debug!("Dask ML standardization: {}", standardization);
786                }
787
788                if let Some(ref encoding) = preprocessing.categorical_encoding {
789                    tracing::debug!("Dask ML categorical encoding: {}", encoding);
790                }
791            }
792
793            if let Some(ref linear_model) = ml.linear_model {
794                if let Some(ref solver) = linear_model.solver {
795                    tracing::debug!("Dask ML linear model solver: {}", solver);
796                }
797
798                let max_iter = linear_model.max_iter.unwrap_or(1000);
799                tracing::debug!("Dask ML linear model max iterations: {}", max_iter);
800            }
801
802            if let Some(ref ensemble) = ml.ensemble {
803                let n_estimators = ensemble.n_estimators.unwrap_or(100);
804                tracing::debug!("Dask ML ensemble estimators: {}", n_estimators);
805
806                let bootstrap = ensemble.bootstrap.unwrap_or(true);
807                tracing::debug!("Dask ML ensemble bootstrap: {}", bootstrap);
808            }
809
810            if let Some(ref cluster) = ml.cluster {
811                let n_clusters = cluster.n_clusters.unwrap_or(8);
812                tracing::debug!("Dask ML clustering clusters: {}", n_clusters);
813
814                let max_iter = cluster.max_iter.unwrap_or(300);
815                tracing::debug!("Dask ML clustering max iterations: {}", max_iter);
816            }
817        }
818        Ok(())
819    }
820
821    /// Setup Dask distributed
822    fn setup_distributed(&self) -> TorshResult<()> {
823        if let Some(ref distributed) = self.config.distributed {
824            tracing::info!("Setting up Dask distributed");
825
826            if let Some(ref comm) = distributed.comm {
827                if let Some(ref compression) = comm.compression {
828                    tracing::debug!("Dask communication compression: {}", compression);
829                }
830
831                if let Some(ref serializers) = comm.serializers {
832                    tracing::debug!("Dask communication serializers: {:?}", serializers);
833                }
834
835                if let Some(ref timeouts) = comm.timeouts {
836                    let connect_timeout = timeouts.connect.unwrap_or(10.0);
837                    tracing::debug!(
838                        "Dask communication connect timeout: {} seconds",
839                        connect_timeout
840                    );
841
842                    let tcp_timeout = timeouts.tcp.unwrap_or(30.0);
843                    tracing::debug!("Dask communication TCP timeout: {} seconds", tcp_timeout);
844                }
845
846                if let Some(ref tcp) = comm.tcp {
847                    let reuse_port = tcp.reuse_port.unwrap_or(false);
848                    tracing::debug!("Dask TCP reuse port: {}", reuse_port);
849
850                    let no_delay = tcp.no_delay.unwrap_or(true);
851                    tracing::debug!("Dask TCP no delay: {}", no_delay);
852
853                    let keep_alive = tcp.keep_alive.unwrap_or(false);
854                    tracing::debug!("Dask TCP keep alive: {}", keep_alive);
855                }
856            }
857
858            if let Some(ref serialization) = distributed.serialization {
859                if let Some(ref compression) = serialization.compression {
860                    tracing::debug!("Dask serialization compression: {:?}", compression);
861                }
862
863                let pickle_protocol = serialization.pickle_protocol.unwrap_or(4);
864                tracing::debug!("Dask serialization pickle protocol: {}", pickle_protocol);
865            }
866
867            if let Some(ref scheduling) = distributed.scheduling {
868                let work_stealing = scheduling.work_stealing.unwrap_or(true);
869                tracing::debug!("Dask scheduling work stealing: {}", work_stealing);
870
871                if work_stealing {
872                    let interval = scheduling.work_stealing_interval.unwrap_or(0.1);
873                    tracing::debug!("Dask work stealing interval: {} seconds", interval);
874                }
875
876                let unknown_task_duration = scheduling.unknown_task_duration.unwrap_or(0.5);
877                tracing::debug!(
878                    "Dask unknown task duration: {} seconds",
879                    unknown_task_duration
880                );
881            }
882
883            if let Some(ref diagnostics) = distributed.diagnostics {
884                let progress_bar = diagnostics.progress_bar.unwrap_or(true);
885                tracing::debug!("Dask diagnostics progress bar: {}", progress_bar);
886
887                let profile = diagnostics.profile.unwrap_or(false);
888                tracing::debug!("Dask diagnostics profile: {}", profile);
889
890                let memory_profiling = diagnostics.memory_profiling.unwrap_or(false);
891                tracing::debug!("Dask diagnostics memory profiling: {}", memory_profiling);
892            }
893        }
894        Ok(())
895    }
896
    /// Get the current configuration
    ///
    /// Returns a shared borrow of the stored [`DaskConfig`]; callers that
    /// need ownership can clone it.
    pub fn config(&self) -> &DaskConfig {
        &self.config
    }
901
    /// Get current statistics
    ///
    /// Returns a shared borrow of the running [`DaskStats`] counters, which
    /// are updated by `submit_task`, `scale_cluster`, `handle_worker_failure`
    /// and `shutdown`.
    pub fn stats(&self) -> &DaskStats {
        &self.stats
    }
906
    /// Check if Dask integration is initialized
    ///
    /// True after a successful `initialize` and until `shutdown` clears it.
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }
911
    /// Get current rank
    ///
    /// Zero-based rank of this process within the world, as recorded during
    /// `initialize`.
    pub fn rank(&self) -> u32 {
        self.rank
    }
916
    /// Get world size
    ///
    /// Total number of processes participating in the job, as recorded
    /// during `initialize`.
    pub fn world_size(&self) -> u32 {
        self.world_size
    }
921
    /// Get local rank
    ///
    /// Zero-based rank of this process on its node, as recorded during
    /// `initialize`.
    pub fn local_rank(&self) -> u32 {
        self.local_rank
    }
926
    /// Get local size
    ///
    /// Number of processes on this node, as recorded during `initialize`.
    pub fn local_size(&self) -> u32 {
        self.local_size
    }
931
    /// Check if Dask client is active
    ///
    /// True while the (simulated) client session is up; cleared by
    /// `shutdown`.
    pub fn is_client_active(&self) -> bool {
        self.client_active
    }
936
937    /// Submit task to Dask cluster
938    pub fn submit_task(&mut self, task_name: &str, task_size: usize) -> TorshResult<String> {
939        if !self.initialized {
940            return Err(TorshDistributedError::BackendNotInitialized);
941        }
942
943        let start_time = std::time::Instant::now();
944
945        tracing::debug!("Submitting Dask task: {} ({} bytes)", task_name, task_size);
946
947        // Simulate task execution
948        let task_id = format!("task_{}_{}", task_name, self.stats.tasks_executed);
949
950        // Update statistics
951        self.stats.tasks_executed += 1;
952        let execution_time = start_time.elapsed().as_secs_f64();
953        self.stats.task_execution_time_sec += execution_time;
954        self.stats.average_task_duration_sec =
955            self.stats.task_execution_time_sec / self.stats.tasks_executed as f64;
956        self.stats.data_transferred_bytes += task_size as u64;
957
958        tracing::debug!("Dask task submitted: {} (ID: {})", task_name, task_id);
959        Ok(task_id)
960    }
961
962    /// Compute Dask collection
963    pub fn compute(&mut self, collection_name: &str) -> TorshResult<()> {
964        if !self.initialized {
965            return Err(TorshDistributedError::BackendNotInitialized);
966        }
967
968        let start_time = std::time::Instant::now();
969
970        tracing::info!("Computing Dask collection: {}", collection_name);
971
972        // Simulate computation
973        let num_tasks = 10; // Simulate breaking down into tasks
974        for i in 0..num_tasks {
975            self.submit_task(&format!("{}_task_{}", collection_name, i), 1024)?;
976        }
977
978        let execution_time = start_time.elapsed().as_secs_f64();
979        tracing::info!(
980            "Dask collection computed: {} in {:.2}s",
981            collection_name,
982            execution_time
983        );
984        Ok(())
985    }
986
987    /// Scale Dask cluster
988    pub fn scale_cluster(&mut self, target_workers: u32) -> TorshResult<()> {
989        if !self.initialized {
990            return Err(TorshDistributedError::BackendNotInitialized);
991        }
992
993        tracing::info!("Scaling Dask cluster to {} workers", target_workers);
994
995        if let Some(ref cluster) = self.config.cluster {
996            if let Some(ref scaling) = cluster.scaling {
997                let minimum = scaling.minimum.unwrap_or(1);
998                let maximum = scaling.maximum.unwrap_or(100);
999
1000                if target_workers < minimum {
1001                    return Err(TorshDistributedError::invalid_argument(
1002                        "target_workers",
1003                        format!("Cannot scale below minimum: {}", minimum),
1004                        format!("At least {} workers", minimum),
1005                    ));
1006                }
1007
1008                if target_workers > maximum {
1009                    return Err(TorshDistributedError::invalid_argument(
1010                        "target_workers",
1011                        format!("Cannot scale above maximum: {}", maximum),
1012                        format!("At most {} workers", maximum),
1013                    ));
1014                }
1015            }
1016        }
1017
1018        self.stats.workers_connected = target_workers;
1019        tracing::info!("Dask cluster scaled to {} workers", target_workers);
1020        Ok(())
1021    }
1022
1023    /// Handle worker failure
1024    pub fn handle_worker_failure(&mut self, worker_id: u32) -> TorshResult<()> {
1025        tracing::warn!("Dask worker {} failed", worker_id);
1026        self.stats.worker_failures += 1;
1027
1028        // Decrease worker count
1029        if self.stats.workers_connected > 0 {
1030            self.stats.workers_connected -= 1;
1031        }
1032
1033        // Auto-scale if configured
1034        if let Some(ref cluster) = self.config.cluster {
1035            if let Some(ref scaling) = cluster.scaling {
1036                let minimum = scaling.minimum.unwrap_or(1);
1037                if self.stats.workers_connected < minimum {
1038                    tracing::info!("Auto-scaling Dask cluster due to worker failure");
1039                    self.scale_cluster(minimum)?;
1040                }
1041            }
1042        }
1043
1044        Ok(())
1045    }
1046
1047    /// Shutdown Dask integration
1048    pub fn shutdown(&mut self) -> TorshResult<()> {
1049        if self.client_active {
1050            tracing::info!("Shutting down Dask integration");
1051            self.client_active = false;
1052            self.initialized = false;
1053            self.stats.workers_connected = 0;
1054        }
1055        Ok(())
1056    }
1057
1058    /// Create a default Dask configuration
1059    pub fn default_config() -> DaskConfig {
1060        DaskConfig {
1061            cluster: Some(DaskClusterConfig {
1062                cluster_type: DaskClusterType::Local,
1063                n_workers: Some(4),
1064                threads_per_worker: Some(2),
1065                memory_limit: Some("4GB".to_string()),
1066                processes: Some(false),
1067                dashboard_address: Some("127.0.0.1:8787".to_string()),
1068                silence_logs: Some(false),
1069                security: None,
1070                scaling: Some(DaskScalingConfig {
1071                    minimum: Some(1),
1072                    maximum: Some(10),
1073                    target_cpu: Some(0.8),
1074                    target_memory: Some(0.8),
1075                    scale_up_threshold: Some(0.8),
1076                    scale_down_threshold: Some(0.2),
1077                    interval: Some(30.0),
1078                }),
1079            }),
1080            scheduler: Some(DaskSchedulerConfig {
1081                address: None,
1082                port: Some(8786),
1083                dashboard_port: Some(8787),
1084                bokeh_port: Some(8788),
1085                worker_timeout: Some(60.0),
1086                idle_timeout: Some(1800.0),
1087                transition_log_length: Some(100000),
1088                task_duration_overhead: Some(0.1),
1089                allowed_failures: Some(3),
1090            }),
1091            worker: Some(DaskWorkerConfig {
1092                nworkers: Some(4),
1093                nthreads: Some(2),
1094                memory_limit: Some("4GB".to_string()),
1095                worker_port: Some("40000:40100".to_string()),
1096                nanny_port: Some("40100:40200".to_string()),
1097                dashboard_port: Some(8789),
1098                death_timeout: Some(60.0),
1099                preload: None,
1100                env: None,
1101                resources: None,
1102            }),
1103            array: Some(DaskArrayConfig {
1104                chunk_size: Some("128MB".to_string()),
1105                backend: Some("numpy".to_string()),
1106                overlap: Some(0),
1107                boundary: Some("reflect".to_string()),
1108                trim: Some(true),
1109                rechunk_threshold: Some(4.0),
1110            }),
1111            dataframe: Some(DaskDataFrameConfig {
1112                partition_size: Some("128MB".to_string()),
1113                shuffle_method: Some(DaskShuffleMethod::Tasks),
1114                query_planning: Some(false),
1115                backend: Some("pandas".to_string()),
1116                optimize_index: Some(true),
1117            }),
1118            bag: Some(DaskBagConfig {
1119                partition_size: Some(134217728), // 128MB
1120                compression: Some("gzip".to_string()),
1121                encoding: Some("utf-8".to_string()),
1122                linedelimiter: Some("\n".to_string()),
1123            }),
1124            ml: None,
1125            distributed: Some(DaskDistributedConfig {
1126                comm: Some(DaskCommConfig {
1127                    compression: Some("lz4".to_string()),
1128                    serializers: Some(vec!["pickle".to_string(), "msgpack".to_string()]),
1129                    timeouts: Some(DaskTimeoutsConfig {
1130                        connect: Some(10.0),
1131                        tcp: Some(30.0),
1132                    }),
1133                    tcp: Some(DaskTcpConfig {
1134                        reuse_port: Some(false),
1135                        no_delay: Some(true),
1136                        keep_alive: Some(false),
1137                    }),
1138                }),
1139                serialization: Some(DaskSerializationConfig {
1140                    compression: Some(vec!["lz4".to_string(), "zlib".to_string()]),
1141                    default_serializers: Some(vec!["pickle".to_string()]),
1142                    pickle_protocol: Some(4),
1143                }),
1144                client: Some(DaskClientConfig {
1145                    heartbeat_interval: Some(5.0),
1146                    scheduler_info_interval: Some(2.0),
1147                    task_metadata: Some(vec!["task_name".to_string(), "worker".to_string()]),
1148                    set_as_default: Some(true),
1149                }),
1150                scheduling: Some(DaskSchedulingConfig {
1151                    work_stealing: Some(true),
1152                    work_stealing_interval: Some(0.1),
1153                    unknown_task_duration: Some(0.5),
1154                    default_task_durations: None,
1155                }),
1156                diagnostics: Some(DaskDiagnosticsConfig {
1157                    progress_bar: Some(true),
1158                    profile: Some(false),
1159                    memory_profiling: Some(false),
1160                    task_stream: Some(false),
1161                    resource_monitor: Some(false),
1162                }),
1163            }),
1164        }
1165    }
1166
1167    /// Create a configuration for machine learning workloads
1168    pub fn config_with_ml() -> DaskConfig {
1169        let mut config = Self::default_config();
1170
1171        config.ml = Some(DaskMLConfig {
1172            model_selection: Some(DaskMLModelSelectionConfig {
1173                cv_folds: Some(5),
1174                scoring: Some("accuracy".to_string()),
1175                n_jobs: Some(-1),
1176                search_method: Some(DaskMLSearchMethod::RandomSearch),
1177            }),
1178            preprocessing: Some(DaskMLPreprocessingConfig {
1179                standardization: Some("StandardScaler".to_string()),
1180                categorical_encoding: Some("OneHotEncoder".to_string()),
1181                feature_selection: Some("SelectKBest".to_string()),
1182                dimensionality_reduction: Some("PCA".to_string()),
1183            }),
1184            linear_model: Some(DaskMLLinearModelConfig {
1185                solver: Some("lbfgs".to_string()),
1186                alpha: Some(1.0),
1187                max_iter: Some(1000),
1188                tol: Some(1e-4),
1189            }),
1190            ensemble: Some(DaskMLEnsembleConfig {
1191                n_estimators: Some(100),
1192                bootstrap: Some(true),
1193                random_state: Some(42),
1194                oob_score: Some(true),
1195            }),
1196            cluster: Some(DaskMLClusterConfig {
1197                n_clusters: Some(8),
1198                init: Some("k-means++".to_string()),
1199                max_iter: Some(300),
1200                tol: Some(1e-4),
1201                n_init: Some(10),
1202            }),
1203        });
1204
1205        config
1206    }
1207
1208    /// Create a configuration for large-scale distributed computing
1209    pub fn config_with_large_scale(n_workers: u32, memory_per_worker: &str) -> DaskConfig {
1210        let mut config = Self::default_config();
1211
1212        if let Some(ref mut cluster) = config.cluster {
1213            cluster.n_workers = Some(n_workers);
1214            cluster.memory_limit = Some(memory_per_worker.to_string());
1215            cluster.processes = Some(true); // Use processes for large scale
1216
1217            if let Some(ref mut scaling) = cluster.scaling {
1218                scaling.minimum = Some(n_workers / 2);
1219                scaling.maximum = Some(n_workers * 2);
1220            }
1221        }
1222
1223        if let Some(ref mut worker) = config.worker {
1224            worker.nworkers = Some(n_workers);
1225            worker.memory_limit = Some(memory_per_worker.to_string());
1226        }
1227
1228        // Optimize for large scale
1229        if let Some(ref mut distributed) = config.distributed {
1230            if let Some(ref mut comm) = distributed.comm {
1231                comm.compression = Some("zstd".to_string()); // Better compression for large data
1232            }
1233
1234            if let Some(ref mut scheduling) = distributed.scheduling {
1235                scheduling.work_stealing_interval = Some(0.5); // Less aggressive for stability
1236            }
1237        }
1238
1239        config
1240    }
1241}
1242
impl Default for DaskConfig {
    /// Delegates to [`DaskIntegration::default_config`] so that
    /// `DaskConfig::default()` yields the same local-cluster baseline used
    /// throughout this module.
    fn default() -> Self {
        DaskIntegration::default_config()
    }
}
1248
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an integration from `config` and bring it up with the given
    /// topology, panicking if initialization fails.
    fn booted(
        config: DaskConfig,
        rank: u32,
        world_size: u32,
        local_rank: u32,
        local_size: u32,
    ) -> DaskIntegration {
        let mut integration = DaskIntegration::new(config);
        integration
            .initialize(rank, world_size, local_rank, local_size)
            .expect("initialization should succeed");
        integration
    }

    #[test]
    fn test_dask_config_validation() {
        let integration = booted(DaskIntegration::default_config(), 0, 4, 0, 2);

        assert!(integration.is_initialized());
        assert!(integration.is_client_active());
        assert_eq!(integration.rank(), 0);
        assert_eq!(integration.world_size(), 4);
        assert_eq!(integration.local_rank(), 0);
        assert_eq!(integration.stats().workers_connected, 4);
    }

    #[test]
    fn test_dask_task_submission() {
        let mut integration = booted(DaskIntegration::default_config(), 0, 4, 0, 2);

        let first = integration.submit_task("compute_gradient", 1024).unwrap();
        let second = integration.submit_task("update_parameters", 2048).unwrap();

        // Ids embed the task name.
        assert!(first.contains("compute_gradient"));
        assert!(second.contains("update_parameters"));

        let stats = integration.stats();
        assert_eq!(stats.tasks_executed, 2);
        assert!(stats.task_execution_time_sec > 0.0);
        assert_eq!(stats.data_transferred_bytes, 3072);
        assert!(stats.average_task_duration_sec > 0.0);
    }

    #[test]
    fn test_dask_compute_collection() {
        let mut integration = booted(DaskIntegration::default_config(), 0, 4, 0, 2);

        // `compute` fans a collection out into ten tasks internally.
        integration.compute("training_dataset").unwrap();

        let stats = integration.stats();
        assert_eq!(stats.tasks_executed, 10);
        assert!(stats.task_execution_time_sec > 0.0);
    }

    #[test]
    fn test_dask_cluster_scaling() {
        let mut integration = booted(DaskIntegration::default_config(), 0, 4, 0, 2);

        // Scale up, then down, within the default [1, 10] window.
        integration.scale_cluster(8).unwrap();
        assert_eq!(integration.stats().workers_connected, 8);

        integration.scale_cluster(2).unwrap();
        assert_eq!(integration.stats().workers_connected, 2);

        // Targets outside the window are rejected.
        assert!(integration.scale_cluster(0).is_err());
        assert!(integration.scale_cluster(20).is_err());
    }

    #[test]
    fn test_dask_worker_failure_handling() {
        let mut integration = booted(DaskIntegration::default_config(), 0, 4, 0, 2);

        integration.handle_worker_failure(1).unwrap();
        assert_eq!(integration.stats().worker_failures, 1);
        assert_eq!(integration.stats().workers_connected, 3);

        // Further losses drive the pool down to the configured minimum.
        integration.handle_worker_failure(2).unwrap();
        integration.handle_worker_failure(3).unwrap();
        assert_eq!(integration.stats().worker_failures, 3);
        assert_eq!(integration.stats().workers_connected, 1);
    }

    #[test]
    fn test_dask_ml_config() {
        let integration = booted(DaskIntegration::config_with_ml(), 0, 4, 0, 2);

        let ml = integration.config().ml.as_ref().expect("ML config present");
        assert!(ml.model_selection.is_some());
        assert!(ml.preprocessing.is_some());
        assert!(ml.linear_model.is_some());
        assert!(ml.ensemble.is_some());
        assert!(ml.cluster.is_some());

        let model_selection = ml.model_selection.as_ref().unwrap();
        assert_eq!(model_selection.cv_folds, Some(5));
        assert_eq!(
            model_selection.search_method,
            Some(DaskMLSearchMethod::RandomSearch)
        );
    }

    #[test]
    fn test_dask_large_scale_config() {
        let integration =
            booted(DaskIntegration::config_with_large_scale(16, "8GB"), 0, 16, 0, 4);

        let cluster = integration
            .config()
            .cluster
            .as_ref()
            .expect("cluster config present");
        assert_eq!(cluster.n_workers, Some(16));
        assert_eq!(cluster.memory_limit, Some("8GB".to_string()));
        assert_eq!(cluster.processes, Some(true));

        let scaling = cluster.scaling.as_ref().expect("scaling config present");
        assert_eq!(scaling.minimum, Some(8));
        assert_eq!(scaling.maximum, Some(32));
    }

    #[test]
    fn test_dask_shutdown() {
        let mut integration = booted(DaskIntegration::default_config(), 0, 4, 0, 2);
        assert!(integration.is_client_active());

        integration.shutdown().unwrap();
        assert!(!integration.is_client_active());
        assert!(!integration.is_initialized());
        assert_eq!(integration.stats().workers_connected, 0);
    }

    #[test]
    fn test_dask_config_serialization() {
        let config = DaskIntegration::config_with_ml();

        // Round-trip through JSON.
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("Local"));
        assert!(json.contains("accuracy"));
        assert!(json.contains("RandomSearch"));

        let restored: DaskConfig = serde_json::from_str(&json).unwrap();
        assert!(restored.ml.is_some());
        let cluster = restored.cluster.expect("cluster survives round-trip");
        assert_eq!(cluster.cluster_type, DaskClusterType::Local);
    }

    #[test]
    fn test_dask_invalid_config() {
        let mut config = DaskIntegration::default_config();
        if let Some(ref mut cluster) = config.cluster {
            cluster.n_workers = Some(0); // Invalid: 0 workers
        }

        let mut integration = DaskIntegration::new(config);
        assert!(integration.initialize(0, 4, 0, 2).is_err());
    }

    #[test]
    fn test_dask_scaling_validation() {
        let mut config = DaskIntegration::default_config();
        if let Some(ref mut cluster) = config.cluster {
            if let Some(ref mut scaling) = cluster.scaling {
                scaling.minimum = Some(10);
                scaling.maximum = Some(5); // Invalid: min > max
            }
        }

        let mut integration = DaskIntegration::new(config);
        assert!(integration.initialize(0, 4, 0, 2).is_err());
    }
}