sklears_inspection/distributed.rs

//! Distributed computation infrastructure for enterprise-scale explanation tasks
//!
//! This module provides distributed computing capabilities for scaling explanation
//! computation across multiple nodes in a cluster. It includes:
//!
//! * Distributed task scheduling and execution
//! * Load balancing across compute nodes
//! * Fault tolerance and retry mechanisms
//! * Result aggregation from multiple workers
//! * Cluster health monitoring
//! * Dynamic worker scaling
//!
//! # Architecture
//!
//! The distributed infrastructure follows a coordinator-worker architecture:
//! - **Coordinator**: Manages task distribution, load balancing, and result aggregation
//! - **Workers**: Execute explanation computation tasks independently
//! - **Task Queue**: Manages pending tasks with priority scheduling
//! - **Result Store**: Aggregates and persists computation results
//!
//! # Example
//!
//! ```rust
//! use sklears_inspection::distributed::{
//!     ClusterConfig, DistributedCoordinator, DistributedTask, LoadBalancingStrategy, TaskType,
//! };
//! use scirs2_core::ndarray::Array2;
//! use std::collections::HashMap;
//! use std::time::Instant;
//!
//! // Create a cluster configuration
//! let config = ClusterConfig {
//!     max_workers: 10,
//!     task_timeout_seconds: 300,
//!     retry_attempts: 3,
//!     load_balancing_strategy: LoadBalancingStrategy::LeastLoaded,
//!     enable_fault_tolerance: true,
//!     ..ClusterConfig::default()
//! };
//!
//! // Initialize the coordinator
//! let coordinator = DistributedCoordinator::new(config)?;
//!
//! // Register workers
//! coordinator.register_worker("worker1".to_string(), "192.168.1.10:8080".to_string())?;
//! coordinator.register_worker("worker2".to_string(), "192.168.1.11:8080".to_string())?;
//!
//! // Submit an explanation task
//! let explanation_task = DistributedTask {
//!     task_id: "task1".to_string(),
//!     task_type: TaskType::ComputeShap,
//!     priority: 1,
//!     input_data: Array2::zeros((10, 5)),
//!     metadata: HashMap::new(),
//!     created_at: Instant::now(),
//! };
//! let task_id = coordinator.submit_task(explanation_task)?;
//!
//! // Dispatch pending tasks to workers, then collect results
//! coordinator.schedule_tasks()?;
//! let result = coordinator.get_result(&task_id)?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```

use crate::types::Float;
use scirs2_core::ndarray::{s, Array1, Array2, Axis};
use sklears_core::error::{Result as SklResult, SklearsError};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

/// Configuration for distributed cluster
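///
/// # Example
///
/// A minimal sketch of overriding a few fields while keeping the remaining defaults
/// (the values are illustrative):
///
/// ```rust
/// use sklears_inspection::distributed::{ClusterConfig, LoadBalancingStrategy};
///
/// let config = ClusterConfig {
///     max_workers: 4,
///     load_balancing_strategy: LoadBalancingStrategy::LeastLoaded,
///     ..ClusterConfig::default()
/// };
/// assert_eq!(config.retry_attempts, 3);
/// ```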
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// Maximum number of workers in the cluster
    pub max_workers: usize,
    /// Task execution timeout in seconds
    pub task_timeout_seconds: u64,
    /// Number of retry attempts for failed tasks
    pub retry_attempts: usize,
    /// Load balancing strategy
    pub load_balancing_strategy: LoadBalancingStrategy,
    /// Enable fault tolerance
    pub enable_fault_tolerance: bool,
    /// Heartbeat interval in seconds
    pub heartbeat_interval_seconds: u64,
    /// Maximum task queue size
    pub max_queue_size: usize,
    /// Enable dynamic worker scaling
    pub enable_auto_scaling: bool,
    /// Target CPU utilization for auto-scaling (0.0-1.0)
    pub target_cpu_utilization: f64,
}

impl Default for ClusterConfig {
    fn default() -> Self {
        Self {
            max_workers: 10,
            task_timeout_seconds: 300,
            retry_attempts: 3,
            load_balancing_strategy: LoadBalancingStrategy::RoundRobin,
            enable_fault_tolerance: true,
            heartbeat_interval_seconds: 30,
            max_queue_size: 1000,
            enable_auto_scaling: false,
            target_cpu_utilization: 0.7,
        }
    }
}

/// Load balancing strategy for task distribution
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Round-robin distribution
    RoundRobin,
    /// Assign to least loaded worker
    LeastLoaded,
    /// Weighted distribution based on worker capacity
    Weighted,
    /// Random assignment
    Random,
    /// Locality-aware (prefer workers with cached data)
    LocalityAware,
}

/// Distributed task coordinator
pub struct DistributedCoordinator {
    /// Configuration
    config: ClusterConfig,
    /// Registered workers
    workers: Arc<Mutex<HashMap<String, WorkerNode>>>,
    /// Task queue
    task_queue: Arc<Mutex<VecDeque<DistributedTask>>>,
    /// Completed results
    results: Arc<Mutex<HashMap<String, TaskResult>>>,
    /// Task assignments
    assignments: Arc<Mutex<HashMap<String, String>>>, // task_id -> worker_id
    /// Round-robin counter for load balancing
    round_robin_counter: Arc<Mutex<usize>>,
    /// Cluster statistics
    statistics: Arc<Mutex<ClusterStatistics>>,
}

impl DistributedCoordinator {
    /// Create a new distributed coordinator
    pub fn new(config: ClusterConfig) -> SklResult<Self> {
        Ok(Self {
            config,
            workers: Arc::new(Mutex::new(HashMap::new())),
            task_queue: Arc::new(Mutex::new(VecDeque::new())),
            results: Arc::new(Mutex::new(HashMap::new())),
            assignments: Arc::new(Mutex::new(HashMap::new())),
            round_robin_counter: Arc::new(Mutex::new(0)),
            statistics: Arc::new(Mutex::new(ClusterStatistics::new())),
        })
    }

    /// Register a worker node
    pub fn register_worker(&self, worker_id: String, address: String) -> SklResult<()> {
        let mut workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;

        if workers.len() >= self.config.max_workers {
            return Err(SklearsError::InvalidInput(
                "Maximum number of workers reached".to_string(),
            ));
        }

        let worker = WorkerNode::new(worker_id.clone(), address);
        workers.insert(worker_id.clone(), worker);

        let mut stats = self.statistics.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
        })?;
        stats.active_workers += 1;

        Ok(())
    }

    /// Unregister a worker node
    pub fn unregister_worker(&self, worker_id: &str) -> SklResult<()> {
        let mut workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;

        if workers.remove(worker_id).is_some() {
            let mut stats = self.statistics.lock().map_err(|_| {
                SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
            })?;
            stats.active_workers = stats.active_workers.saturating_sub(1);
            Ok(())
        } else {
            Err(SklearsError::InvalidInput(format!(
                "Worker {} not found",
                worker_id
            )))
        }
    }

    /// Submit a task for distributed execution
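    ///
    /// # Example
    ///
    /// A minimal sketch; the task payload below is a placeholder:
    ///
    /// ```rust
    /// use sklears_inspection::distributed::{
    ///     ClusterConfig, DistributedCoordinator, DistributedTask, TaskType,
    /// };
    /// use scirs2_core::ndarray::Array2;
    /// use std::collections::HashMap;
    /// use std::time::Instant;
    ///
    /// let coordinator = DistributedCoordinator::new(ClusterConfig::default())?;
    /// let task = DistributedTask {
    ///     task_id: "task1".to_string(),
    ///     task_type: TaskType::ComputeShap,
    ///     priority: 1,
    ///     input_data: Array2::zeros((4, 2)),
    ///     metadata: HashMap::new(),
    ///     created_at: Instant::now(),
    /// };
    /// assert_eq!(coordinator.submit_task(task)?, "task1");
    /// assert_eq!(coordinator.get_statistics()?.pending_tasks, 1);
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```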
    pub fn submit_task(&self, task: DistributedTask) -> SklResult<String> {
        let mut queue = self
            .task_queue
            .lock()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire queue lock".to_string()))?;

        if queue.len() >= self.config.max_queue_size {
            return Err(SklearsError::InvalidInput("Task queue is full".to_string()));
        }

        let task_id = task.task_id.clone();
        queue.push_back(task);

        let mut stats = self.statistics.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
        })?;
        stats.total_tasks_submitted += 1;
        stats.pending_tasks += 1;

        Ok(task_id)
    }

    /// Process pending tasks and assign to workers
    pub fn schedule_tasks(&self) -> SklResult<usize> {
        let mut scheduled = 0;

        loop {
            // Get next task from queue
            let task = {
                let mut queue = self.task_queue.lock().map_err(|_| {
                    SklearsError::InvalidInput("Failed to acquire queue lock".to_string())
                })?;
                queue.pop_front()
            };

            match task {
                None => break, // No more tasks
                Some(task) => {
                    // Select worker based on load balancing strategy
                    let worker_id = self.select_worker(&task)?;

                    // Assign task to worker
                    self.assign_task_to_worker(task, &worker_id)?;

                    scheduled += 1;
                }
            }
        }

        Ok(scheduled)
    }

    /// Select a worker based on load balancing strategy
    fn select_worker(&self, task: &DistributedTask) -> SklResult<String> {
        let workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;

        if workers.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No workers available".to_string(),
            ));
        }

        match self.config.load_balancing_strategy {
            LoadBalancingStrategy::RoundRobin => {
                let mut counter = self.round_robin_counter.lock().map_err(|_| {
                    SklearsError::InvalidInput("Failed to acquire counter lock".to_string())
                })?;

                let worker_ids: Vec<String> = workers.keys().cloned().collect();
                let selected = &worker_ids[*counter % worker_ids.len()];
                *counter += 1;

                Ok(selected.clone())
            }
            LoadBalancingStrategy::LeastLoaded => {
                let mut least_loaded_worker = None;
                let mut min_load = usize::MAX;

                for (worker_id, worker) in workers.iter() {
                    if worker.current_load < min_load {
                        min_load = worker.current_load;
                        least_loaded_worker = Some(worker_id.clone());
                    }
                }

                least_loaded_worker.ok_or_else(|| {
                    SklearsError::InvalidInput("Failed to find least loaded worker".to_string())
                })
            }
            LoadBalancingStrategy::Weighted => {
                // Simplified: use worker capacity as weight
                let mut best_worker = None;
                let mut best_score = 0.0;

                for (worker_id, worker) in workers.iter() {
                    let score = (worker.capacity as f64) / (worker.current_load as f64 + 1.0);
                    if score > best_score {
                        best_score = score;
                        best_worker = Some(worker_id.clone());
                    }
                }

                best_worker.ok_or_else(|| {
                    SklearsError::InvalidInput("Failed to find weighted worker".to_string())
                })
            }
            LoadBalancingStrategy::Random => {
                use scirs2_core::random::{thread_rng, CoreRandom};

                let worker_ids: Vec<String> = workers.keys().cloned().collect();
                let mut rng = thread_rng();
                let index = rng.gen_range(0..worker_ids.len());
                Ok(worker_ids[index].clone())
            }
            LoadBalancingStrategy::LocalityAware => {
                // Simplified: prefer worker with lowest network latency
                let mut best_worker = None;
                let mut best_latency = Duration::from_secs(u64::MAX);

                for (worker_id, worker) in workers.iter() {
                    if worker.network_latency < best_latency {
                        best_latency = worker.network_latency;
                        best_worker = Some(worker_id.clone());
                    }
                }

                best_worker.ok_or_else(|| {
                    SklearsError::InvalidInput("Failed to find locality-aware worker".to_string())
                })
            }
        }
    }

    /// Assign task to a specific worker
    fn assign_task_to_worker(&self, task: DistributedTask, worker_id: &str) -> SklResult<()> {
        // Update worker state and record assignment (in separate scope to release locks)
        {
            let mut workers = self.workers.lock().map_err(|_| {
                SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
            })?;

            let worker = workers.get_mut(worker_id).ok_or_else(|| {
                SklearsError::InvalidInput(format!("Worker {} not found", worker_id))
            })?;

            // Update worker state
            worker.current_load += 1;
            worker.total_tasks_processed += 1;
        } // Release workers lock here

        // Record assignment
        {
            let mut assignments = self.assignments.lock().map_err(|_| {
                SklearsError::InvalidInput("Failed to acquire assignments lock".to_string())
            })?;
            assignments.insert(task.task_id.clone(), worker_id.to_string());
        } // Release assignments lock here

        // Update statistics
        {
            let mut stats = self.statistics.lock().map_err(|_| {
                SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
            })?;
            stats.pending_tasks = stats.pending_tasks.saturating_sub(1);
            stats.running_tasks += 1;
        } // Release stats lock here

        // In a real implementation, this would send the task to the worker
        // For now, we simulate immediate processing (all locks released before this call)
        self.simulate_task_execution(task, worker_id.to_string())?;

        Ok(())
    }

    /// Simulate task execution (in real implementation, this would be async)
    fn simulate_task_execution(&self, task: DistributedTask, worker_id: String) -> SklResult<()> {
        // Simulate task processing
        let result = TaskResult {
            task_id: task.task_id.clone(),
            worker_id: worker_id.clone(),
            status: TaskStatus::Completed,
            result_data: Array1::zeros(10), // Placeholder
            execution_time: Duration::from_millis(100),
            retry_count: 0,
        };

        // Store result
        let mut results = self.results.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire results lock".to_string())
        })?;
        results.insert(task.task_id.clone(), result);

        // Update worker load
        let mut workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;
        if let Some(worker) = workers.get_mut(&worker_id) {
            worker.current_load = worker.current_load.saturating_sub(1);
        }

        // Update statistics
        let mut stats = self.statistics.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
        })?;
        stats.running_tasks = stats.running_tasks.saturating_sub(1);
        stats.completed_tasks += 1;

        Ok(())
    }

    /// Get result for a completed task
    pub fn get_result(&self, task_id: &str) -> SklResult<TaskResult> {
        let results = self.results.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire results lock".to_string())
        })?;

        results.get(task_id).cloned().ok_or_else(|| {
            SklearsError::InvalidInput(format!("Result for task {} not found", task_id))
        })
    }

    /// Get cluster statistics
    pub fn get_statistics(&self) -> SklResult<ClusterStatistics> {
        let stats = self.statistics.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
        })?;

        Ok(stats.clone())
    }

    /// Get worker information
    pub fn get_worker_info(&self, worker_id: &str) -> SklResult<WorkerNode> {
        let workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;

        workers
            .get(worker_id)
            .cloned()
            .ok_or_else(|| SklearsError::InvalidInput(format!("Worker {} not found", worker_id)))
    }

    /// Get all workers
    pub fn get_all_workers(&self) -> SklResult<Vec<WorkerNode>> {
        let workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;

        Ok(workers.values().cloned().collect())
    }

    /// Health check for the cluster
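    ///
    /// # Example
    ///
    /// A minimal sketch of inspecting cluster health after registering a single worker
    /// (the address is illustrative):
    ///
    /// ```rust
    /// use sklears_inspection::distributed::{ClusterConfig, DistributedCoordinator, HealthStatus};
    ///
    /// let coordinator = DistributedCoordinator::new(ClusterConfig::default())?;
    /// coordinator.register_worker("worker1".to_string(), "10.0.0.5:8080".to_string())?;
    ///
    /// let health = coordinator.health_check()?;
    /// assert_eq!(health.status, HealthStatus::Healthy);
    /// assert_eq!(health.total_workers, 1);
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```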
    pub fn health_check(&self) -> SklResult<ClusterHealth> {
        let workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;

        let stats = self.statistics.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
        })?;

        let total_workers = workers.len();
        let healthy_workers = workers.values().filter(|w| w.is_healthy).count();

        let health_status = if healthy_workers == 0 {
            HealthStatus::Critical
        } else if healthy_workers < total_workers / 2 {
            HealthStatus::Degraded
        } else if healthy_workers < total_workers {
            HealthStatus::Warning
        } else {
            HealthStatus::Healthy
        };

        Ok(ClusterHealth {
            status: health_status,
            total_workers,
            healthy_workers,
            total_capacity: workers.values().map(|w| w.capacity).sum(),
            current_load: workers.values().map(|w| w.current_load).sum(),
            pending_tasks: stats.pending_tasks,
            running_tasks: stats.running_tasks,
        })
    }
}

/// Worker node in the cluster
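///
/// # Example
///
/// A minimal sketch of tracking load on a worker (the address is illustrative):
///
/// ```rust
/// use sklears_inspection::distributed::WorkerNode;
///
/// let mut worker = WorkerNode::new("worker1".to_string(), "10.0.0.5:8080".to_string());
/// worker.current_load = 4;
/// assert_eq!(worker.available_capacity(), 6);
/// assert!(!worker.is_overloaded());
/// ```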
#[derive(Debug, Clone)]
pub struct WorkerNode {
    /// Worker identifier
    pub worker_id: String,
    /// Network address
    pub address: String,
    /// Worker capacity (max concurrent tasks)
    pub capacity: usize,
    /// Current load (number of active tasks)
    pub current_load: usize,
    /// Total tasks processed
    pub total_tasks_processed: usize,
    /// Worker health status
    pub is_healthy: bool,
    /// Last heartbeat time
    pub last_heartbeat: Instant,
    /// Network latency
    pub network_latency: Duration,
    /// CPU utilization (0.0-1.0)
    pub cpu_utilization: f64,
    /// Memory utilization (0.0-1.0)
    pub memory_utilization: f64,
}

impl WorkerNode {
    /// Create a new worker node
    pub fn new(worker_id: String, address: String) -> Self {
        Self {
            worker_id,
            address,
            capacity: 10,
            current_load: 0,
            total_tasks_processed: 0,
            is_healthy: true,
            last_heartbeat: Instant::now(),
            network_latency: Duration::from_millis(10),
            cpu_utilization: 0.0,
            memory_utilization: 0.0,
        }
    }

    /// Update heartbeat
    pub fn heartbeat(&mut self) {
        self.last_heartbeat = Instant::now();
        self.is_healthy = true;
    }

    /// Check if worker is overloaded
    pub fn is_overloaded(&self) -> bool {
        self.current_load >= self.capacity
    }

    /// Get available capacity
    pub fn available_capacity(&self) -> usize {
        self.capacity.saturating_sub(self.current_load)
    }
}

/// Distributed task
#[derive(Debug, Clone)]
pub struct DistributedTask {
    /// Task identifier
    pub task_id: String,
    /// Task type
    pub task_type: TaskType,
    /// Task priority (higher = more important)
    pub priority: usize,
    /// Input data
    pub input_data: Array2<Float>,
    /// Task metadata
    pub metadata: HashMap<String, String>,
    /// Creation time
    pub created_at: Instant,
}

/// Type of distributed task
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskType {
    /// Compute SHAP values
    ComputeShap,
    /// Compute permutation importance
    ComputePermutationImportance,
    /// Generate counterfactual explanations
    GenerateCounterfactuals,
    /// Compute feature importance
    ComputeFeatureImportance,
    /// Batch explanation generation
    BatchExplanation,
}

/// Task execution result
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// Task identifier
    pub task_id: String,
    /// Worker that processed the task
    pub worker_id: String,
    /// Task status
    pub status: TaskStatus,
    /// Result data
    pub result_data: Array1<Float>,
    /// Execution time
    pub execution_time: Duration,
    /// Number of retries
    pub retry_count: usize,
}

/// Task execution status
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    /// Task is pending
    Pending,
    /// Task is running
    Running,
    /// Task completed successfully
    Completed,
    /// Task failed
    Failed,
    /// Task was cancelled
    Cancelled,
}

/// Cluster statistics
#[derive(Debug, Clone)]
pub struct ClusterStatistics {
    /// Number of active workers
    pub active_workers: usize,
    /// Total tasks submitted
    pub total_tasks_submitted: usize,
    /// Pending tasks
    pub pending_tasks: usize,
    /// Running tasks
    pub running_tasks: usize,
    /// Completed tasks
    pub completed_tasks: usize,
    /// Failed tasks
    pub failed_tasks: usize,
    /// Average task execution time
    pub avg_execution_time: Duration,
    /// Total data processed (in bytes)
    pub total_data_processed: usize,
}

impl ClusterStatistics {
    fn new() -> Self {
        Self {
            active_workers: 0,
            total_tasks_submitted: 0,
            pending_tasks: 0,
            running_tasks: 0,
            completed_tasks: 0,
            failed_tasks: 0,
            avg_execution_time: Duration::from_secs(0),
            total_data_processed: 0,
        }
    }
}

/// Cluster health information
#[derive(Debug, Clone)]
pub struct ClusterHealth {
    /// Overall health status
    pub status: HealthStatus,
    /// Total number of workers
    pub total_workers: usize,
    /// Number of healthy workers
    pub healthy_workers: usize,
    /// Total cluster capacity
    pub total_capacity: usize,
    /// Current cluster load
    pub current_load: usize,
    /// Number of pending tasks
    pub pending_tasks: usize,
    /// Number of running tasks
    pub running_tasks: usize,
}

/// Health status of the cluster
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    /// All workers healthy
    Healthy,
    /// Some workers degraded
    Warning,
    /// Many workers unhealthy
    Degraded,
    /// Critical failure
    Critical,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cluster_config_default() {
        let config = ClusterConfig::default();
        assert_eq!(config.max_workers, 10);
        assert_eq!(config.task_timeout_seconds, 300);
        assert_eq!(config.retry_attempts, 3);
        assert_eq!(
            config.load_balancing_strategy,
            LoadBalancingStrategy::RoundRobin
        );
        assert!(config.enable_fault_tolerance);
    }

    #[test]
    fn test_distributed_coordinator_creation() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config);
        assert!(coordinator.is_ok());
    }

    #[test]
    fn test_register_worker() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config).unwrap();

        let result =
            coordinator.register_worker("worker1".to_string(), "192.168.1.10:8080".to_string());
        assert!(result.is_ok());

        let stats = coordinator.get_statistics().unwrap();
        assert_eq!(stats.active_workers, 1);
    }

    #[test]
    fn test_register_multiple_workers() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config).unwrap();

        coordinator
            .register_worker("worker1".to_string(), "192.168.1.10:8080".to_string())
            .unwrap();
        coordinator
            .register_worker("worker2".to_string(), "192.168.1.11:8080".to_string())
            .unwrap();

        let stats = coordinator.get_statistics().unwrap();
        assert_eq!(stats.active_workers, 2);
    }

    #[test]
    fn test_register_worker_limit() {
        let config = ClusterConfig {
            max_workers: 2,
            ..Default::default()
        };
        let coordinator = DistributedCoordinator::new(config).unwrap();

        coordinator
            .register_worker("worker1".to_string(), "192.168.1.10:8080".to_string())
            .unwrap();
        coordinator
            .register_worker("worker2".to_string(), "192.168.1.11:8080".to_string())
            .unwrap();

        let result =
            coordinator.register_worker("worker3".to_string(), "192.168.1.12:8080".to_string());
        assert!(result.is_err());
    }

    #[test]
    fn test_unregister_worker() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config).unwrap();

        coordinator
            .register_worker("worker1".to_string(), "192.168.1.10:8080".to_string())
            .unwrap();

        let result = coordinator.unregister_worker("worker1");
        assert!(result.is_ok());

        let stats = coordinator.get_statistics().unwrap();
        assert_eq!(stats.active_workers, 0);
    }

    #[test]
    fn test_submit_task() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config).unwrap();

        let task = DistributedTask {
            task_id: "task1".to_string(),
            task_type: TaskType::ComputeShap,
            priority: 1,
            input_data: Array2::zeros((10, 5)),
            metadata: HashMap::new(),
            created_at: Instant::now(),
        };

        let result = coordinator.submit_task(task);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "task1");

        let stats = coordinator.get_statistics().unwrap();
        assert_eq!(stats.total_tasks_submitted, 1);
        assert_eq!(stats.pending_tasks, 1);
    }

    #[test]
    fn test_schedule_tasks() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config).unwrap();

        // Register a worker
        coordinator
            .register_worker("worker1".to_string(), "192.168.1.10:8080".to_string())
            .unwrap();

        // Submit a task
        let task = DistributedTask {
            task_id: "task1".to_string(),
            task_type: TaskType::ComputeShap,
            priority: 1,
            input_data: Array2::zeros((10, 5)),
            metadata: HashMap::new(),
            created_at: Instant::now(),
        };
        coordinator.submit_task(task).unwrap();

        // Schedule tasks
        let scheduled = coordinator.schedule_tasks().unwrap();
        assert_eq!(scheduled, 1);
    }

    #[test]
    fn test_worker_node_creation() {
        let worker = WorkerNode::new("worker1".to_string(), "192.168.1.10:8080".to_string());
        assert_eq!(worker.worker_id, "worker1");
        assert_eq!(worker.address, "192.168.1.10:8080");
        assert_eq!(worker.capacity, 10);
        assert_eq!(worker.current_load, 0);
        assert!(worker.is_healthy);
    }

    #[test]
    fn test_worker_node_overload() {
        let mut worker = WorkerNode::new("worker1".to_string(), "192.168.1.10:8080".to_string());
        worker.capacity = 5;
        worker.current_load = 3;

        assert!(!worker.is_overloaded());

        worker.current_load = 5;
        assert!(worker.is_overloaded());

        worker.current_load = 6;
        assert!(worker.is_overloaded());
    }

    #[test]
    fn test_worker_node_available_capacity() {
        let mut worker = WorkerNode::new("worker1".to_string(), "192.168.1.10:8080".to_string());
        worker.capacity = 10;
        worker.current_load = 3;

        assert_eq!(worker.available_capacity(), 7);
    }

    #[test]
    fn test_cluster_health_check() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config).unwrap();

        coordinator
            .register_worker("worker1".to_string(), "192.168.1.10:8080".to_string())
            .unwrap();
        coordinator
            .register_worker("worker2".to_string(), "192.168.1.11:8080".to_string())
            .unwrap();

        let health = coordinator.health_check().unwrap();
        assert_eq!(health.status, HealthStatus::Healthy);
        assert_eq!(health.total_workers, 2);
        assert_eq!(health.healthy_workers, 2);
    }

    #[test]
    fn test_load_balancing_strategies() {
        // Test that all strategy variants can be created and are unique
        assert_ne!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::LeastLoaded
        );
        assert_ne!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::Weighted
        );
        assert_ne!(
            LoadBalancingStrategy::LeastLoaded,
            LoadBalancingStrategy::Weighted
        );
        assert_ne!(
            LoadBalancingStrategy::Random,
            LoadBalancingStrategy::LocalityAware
        );

        // Test equality
        assert_eq!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::RoundRobin
        );
    }
}

/// Cluster-based explanation computation orchestrator
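///
/// # Example
///
/// A minimal sketch that registers a worker and runs the (simulated) distributed SHAP
/// pipeline end to end; the address, data shapes, and batch size are illustrative:
///
/// ```rust
/// use sklears_inspection::distributed::{ClusterConfig, ClusterExplanationOrchestrator, WorkerConfig};
/// use scirs2_core::ndarray::Array2;
///
/// let orchestrator = ClusterExplanationOrchestrator::new(ClusterConfig::default())?;
/// orchestrator.register_workers_from_config(vec![WorkerConfig {
///     worker_id: "worker1".to_string(),
///     address: "10.0.0.5:8080".to_string(),
///     capacity: 10,
/// }])?;
///
/// let data = Array2::zeros((8, 3));
/// let background = Array2::zeros((4, 3));
/// let shap_values = orchestrator.compute_shap_distributed(&data, &background, 4)?;
/// assert_eq!(shap_values.nrows(), 8);
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```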
pub struct ClusterExplanationOrchestrator {
    /// Distributed coordinator
    coordinator: Arc<DistributedCoordinator>,
    /// Configuration
    config: ClusterConfig,
    /// Explanation cache
    cache: Arc<Mutex<HashMap<String, CachedExplanation>>>,
    /// Active batch computations
    active_batches: Arc<Mutex<HashMap<String, BatchComputation>>>,
}

impl ClusterExplanationOrchestrator {
    /// Create a new cluster explanation orchestrator
    pub fn new(config: ClusterConfig) -> SklResult<Self> {
        let coordinator = Arc::new(DistributedCoordinator::new(config.clone())?);

        Ok(Self {
            coordinator,
            config,
            cache: Arc::new(Mutex::new(HashMap::new())),
            active_batches: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Register workers from a configuration
    pub fn register_workers_from_config(&self, worker_configs: Vec<WorkerConfig>) -> SklResult<()> {
        for worker_config in worker_configs {
            self.coordinator
                .register_worker(worker_config.worker_id, worker_config.address)?;
        }
        Ok(())
    }

    /// Compute SHAP values across the cluster
    pub fn compute_shap_distributed(
        &self,
        data: &Array2<Float>,
        background_data: &Array2<Float>,
        batch_size: usize,
    ) -> SklResult<Array2<Float>> {
        let n_samples = data.nrows();
        let n_features = data.ncols();

        // Create batch ID
        let batch_id = format!("shap_batch_{}", uuid::Uuid::new_v4());

        // Split data into batches
        let batches = self.split_into_batches(data, batch_size)?;

        // Submit tasks for each batch
        let mut task_ids = Vec::new();
        for (batch_idx, batch) in batches.iter().enumerate() {
            let task = DistributedTask {
                task_id: format!("{}_task_{}", batch_id, batch_idx),
                task_type: TaskType::ComputeShap,
                priority: 1,
                input_data: batch.clone(),
                metadata: {
                    let mut meta = HashMap::new();
                    meta.insert("batch_id".to_string(), batch_id.clone());
                    meta.insert("batch_idx".to_string(), batch_idx.to_string());
                    meta
                },
                created_at: Instant::now(),
            };

            let task_id = self.coordinator.submit_task(task)?;
            task_ids.push(task_id);
        }

        // Schedule all tasks
        self.coordinator.schedule_tasks()?;

        // Collect results
        let mut all_results = Vec::new();
        for task_id in task_ids {
            let result = self.coordinator.get_result(&task_id)?;
            all_results.push(result.result_data);
        }

        // Aggregate results
        let aggregated = self.aggregate_shap_results(all_results, n_samples, n_features)?;

        Ok(aggregated)
    }

    /// Compute feature importance across the cluster
    pub fn compute_feature_importance_distributed(
        &self,
        data: &Array2<Float>,
        predictions: &Array1<Float>,
        batch_size: usize,
    ) -> SklResult<Array1<Float>> {
        let n_samples = data.nrows();
        let n_features = data.ncols();

        // Create batch ID
        let batch_id = format!("importance_batch_{}", uuid::Uuid::new_v4());

        // Split data into batches
        let data_batches = self.split_into_batches(data, batch_size)?;
        let pred_batches = self.split_predictions(predictions, batch_size)?;

        // Submit tasks for each batch
        let mut task_ids = Vec::new();
        for (batch_idx, (data_batch, pred_batch)) in
            data_batches.iter().zip(pred_batches.iter()).enumerate()
        {
            let task = DistributedTask {
                task_id: format!("{}_task_{}", batch_id, batch_idx),
                task_type: TaskType::ComputeFeatureImportance,
                priority: 1,
                input_data: data_batch.clone(),
                metadata: {
                    let mut meta = HashMap::new();
                    meta.insert("batch_id".to_string(), batch_id.clone());
                    meta.insert("batch_idx".to_string(), batch_idx.to_string());
                    meta
                },
                created_at: Instant::now(),
            };

            let task_id = self.coordinator.submit_task(task)?;
            task_ids.push(task_id);
        }

        // Schedule all tasks
        self.coordinator.schedule_tasks()?;

        // Collect and average results
        let mut importance_sum = Array1::zeros(n_features);
        let mut count = 0;

        for task_id in task_ids {
            let result = self.coordinator.get_result(&task_id)?;
            importance_sum += &result.result_data.slice(s![..n_features]).to_owned();
            count += 1;
        }

        // Average importance across batches
        Ok(importance_sum / (count as Float))
    }

    /// Generate counterfactuals across the cluster
    pub fn generate_counterfactuals_distributed(
        &self,
        instances: &Array2<Float>,
        target_class: usize,
        n_counterfactuals_per_instance: usize,
    ) -> SklResult<Vec<Array1<Float>>> {
        let batch_id = format!("counterfactual_batch_{}", uuid::Uuid::new_v4());

        // Submit tasks for each instance
        let mut task_ids = Vec::new();
        for (instance_idx, instance) in instances.axis_iter(Axis(0)).enumerate() {
            let task = DistributedTask {
                task_id: format!("{}_task_{}", batch_id, instance_idx),
                task_type: TaskType::GenerateCounterfactuals,
                priority: 2,
                input_data: instance.to_owned().insert_axis(Axis(0)),
                metadata: {
                    let mut meta = HashMap::new();
                    meta.insert("batch_id".to_string(), batch_id.clone());
                    meta.insert("instance_idx".to_string(), instance_idx.to_string());
                    meta.insert("target_class".to_string(), target_class.to_string());
                    meta
                },
                created_at: Instant::now(),
            };

            let task_id = self.coordinator.submit_task(task)?;
            task_ids.push(task_id);
        }

        // Schedule all tasks
        self.coordinator.schedule_tasks()?;

        // Collect counterfactuals
        let mut all_counterfactuals = Vec::new();
        for task_id in task_ids {
            let result = self.coordinator.get_result(&task_id)?;
            all_counterfactuals.push(result.result_data);
        }

        Ok(all_counterfactuals)
    }

    /// Split data into batches
    fn split_into_batches(
        &self,
        data: &Array2<Float>,
        batch_size: usize,
    ) -> SklResult<Vec<Array2<Float>>> {
        let n_samples = data.nrows();
        let mut batches = Vec::new();

        for start_idx in (0..n_samples).step_by(batch_size) {
            let end_idx = (start_idx + batch_size).min(n_samples);
            let batch = data.slice(s![start_idx..end_idx, ..]).to_owned();
            batches.push(batch);
        }

        Ok(batches)
    }

    /// Split predictions into batches
    fn split_predictions(
        &self,
        predictions: &Array1<Float>,
        batch_size: usize,
    ) -> SklResult<Vec<Array1<Float>>> {
        let n_samples = predictions.len();
        let mut batches = Vec::new();

        for start_idx in (0..n_samples).step_by(batch_size) {
            let end_idx = (start_idx + batch_size).min(n_samples);
            let batch = predictions.slice(s![start_idx..end_idx]).to_owned();
            batches.push(batch);
        }

        Ok(batches)
    }

    /// Aggregate SHAP results from multiple workers
    fn aggregate_shap_results(
        &self,
        results: Vec<Array1<Float>>,
        n_samples: usize,
        n_features: usize,
    ) -> SklResult<Array2<Float>> {
        let mut aggregated = Array2::zeros((n_samples, n_features));

        let mut sample_idx = 0;
        for result in results {
            // Each result contains SHAP values for a batch of samples
            let batch_size = result.len() / n_features;
            for i in 0..batch_size {
                if sample_idx < n_samples {
                    for j in 0..n_features {
                        let result_idx = i * n_features + j;
                        if result_idx < result.len() {
                            aggregated[[sample_idx, j]] = result[result_idx];
                        }
                    }
                    sample_idx += 1;
                }
            }
        }

        Ok(aggregated)
    }

    /// Get cluster statistics
    pub fn get_cluster_statistics(&self) -> SklResult<ClusterStatistics> {
        self.coordinator.get_statistics()
    }

    /// Get cluster health
    pub fn get_cluster_health(&self) -> SklResult<ClusterHealth> {
        self.coordinator.health_check()
    }

    /// Scale cluster up by adding workers
    pub fn scale_up(&self, new_workers: Vec<WorkerConfig>) -> SklResult<()> {
        self.register_workers_from_config(new_workers)
    }

    /// Scale cluster down by removing workers
    pub fn scale_down(&self, worker_ids: Vec<String>) -> SklResult<()> {
        for worker_id in worker_ids {
            self.coordinator.unregister_worker(&worker_id)?;
        }
        Ok(())
    }
}

/// Worker configuration
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Worker identifier
    pub worker_id: String,
    /// Network address
    pub address: String,
    /// Worker capacity
    pub capacity: usize,
}

/// Cached explanation
#[derive(Debug, Clone)]
struct CachedExplanation {
    /// Explanation data
    data: Array1<Float>,
    /// Cache timestamp
    cached_at: Instant,
    /// Cache hit count
    hit_count: usize,
}

/// Batch computation tracking
#[derive(Debug, Clone)]
struct BatchComputation {
    /// Batch identifier
    batch_id: String,
    /// Task IDs in this batch
    task_ids: Vec<String>,
    /// Start time
    started_at: Instant,
    /// Completion status
    is_complete: bool,
}

#[cfg(test)]
mod cluster_tests {
    use super::*;

    #[test]
    fn test_cluster_orchestrator_creation() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config);
        assert!(orchestrator.is_ok());
    }

    #[test]
    fn test_register_workers_from_config() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config).unwrap();

        let worker_configs = vec![
            WorkerConfig {
                worker_id: "worker1".to_string(),
                address: "192.168.1.10:8080".to_string(),
                capacity: 10,
            },
            WorkerConfig {
                worker_id: "worker2".to_string(),
                address: "192.168.1.11:8080".to_string(),
                capacity: 10,
            },
        ];

        let result = orchestrator.register_workers_from_config(worker_configs);
        assert!(result.is_ok());

        let stats = orchestrator.get_cluster_statistics().unwrap();
        assert_eq!(stats.active_workers, 2);
    }

    #[test]
    fn test_split_into_batches() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config).unwrap();

        let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as Float).collect()).unwrap();
        let batches = orchestrator.split_into_batches(&data, 3).unwrap();

        assert_eq!(batches.len(), 4); // 10 samples / 3 per batch = 4 batches
        assert_eq!(batches[0].nrows(), 3);
        assert_eq!(batches[1].nrows(), 3);
        assert_eq!(batches[2].nrows(), 3);
        assert_eq!(batches[3].nrows(), 1); // Last batch has remainder
    }

    #[test]
    fn test_split_predictions() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config).unwrap();

        let predictions = Array1::from_vec((0..10).map(|x| x as Float).collect());
        let batches = orchestrator.split_predictions(&predictions, 4).unwrap();

        assert_eq!(batches.len(), 3); // 10 samples / 4 per batch = 3 batches
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[1].len(), 4);
        assert_eq!(batches[2].len(), 2); // Last batch has remainder
    }

    #[test]
    fn test_cluster_health() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config).unwrap();

        // Register workers
        let worker_configs = vec![WorkerConfig {
            worker_id: "worker1".to_string(),
            address: "192.168.1.10:8080".to_string(),
            capacity: 10,
        }];
        orchestrator
            .register_workers_from_config(worker_configs)
            .unwrap();

        let health = orchestrator.get_cluster_health().unwrap();
        assert_eq!(health.status, HealthStatus::Healthy);
        assert_eq!(health.total_workers, 1);
        assert_eq!(health.healthy_workers, 1);
    }

    #[test]
    fn test_scale_up() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config).unwrap();

        // Start with one worker
        let initial_workers = vec![WorkerConfig {
            worker_id: "worker1".to_string(),
            address: "192.168.1.10:8080".to_string(),
            capacity: 10,
        }];
        orchestrator
            .register_workers_from_config(initial_workers)
            .unwrap();

        // Scale up
        let new_workers = vec![
            WorkerConfig {
                worker_id: "worker2".to_string(),
                address: "192.168.1.11:8080".to_string(),
                capacity: 10,
            },
            WorkerConfig {
                worker_id: "worker3".to_string(),
                address: "192.168.1.12:8080".to_string(),
                capacity: 10,
            },
        ];
        let result = orchestrator.scale_up(new_workers);
        assert!(result.is_ok());

        let stats = orchestrator.get_cluster_statistics().unwrap();
        assert_eq!(stats.active_workers, 3);
    }

    #[test]
    fn test_scale_down() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config).unwrap();

        // Register workers
        let worker_configs = vec![
            WorkerConfig {
                worker_id: "worker1".to_string(),
                address: "192.168.1.10:8080".to_string(),
                capacity: 10,
            },
            WorkerConfig {
                worker_id: "worker2".to_string(),
                address: "192.168.1.11:8080".to_string(),
                capacity: 10,
            },
            WorkerConfig {
                worker_id: "worker3".to_string(),
                address: "192.168.1.12:8080".to_string(),
                capacity: 10,
            },
        ];
        orchestrator
            .register_workers_from_config(worker_configs)
            .unwrap();

        // Scale down
        let result = orchestrator.scale_down(vec!["worker3".to_string()]);
        assert!(result.is_ok());

        let stats = orchestrator.get_cluster_statistics().unwrap();
        assert_eq!(stats.active_workers, 2);
    }

    #[test]
    fn test_worker_config_creation() {
        let config = WorkerConfig {
            worker_id: "test_worker".to_string(),
            address: "localhost:8080".to_string(),
            capacity: 20,
        };

        assert_eq!(config.worker_id, "test_worker");
        assert_eq!(config.address, "localhost:8080");
        assert_eq!(config.capacity, 20);
    }
}