use crate::types::Float;
use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::error::{Result as SklResult, SklearsError};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

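/// Configuration options for a distributed explanation cluster.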
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// Maximum number of workers that may register with the coordinator.
    pub max_workers: usize,
    /// Per-task timeout, in seconds.
    pub task_timeout_seconds: u64,
    /// Number of retry attempts for a failed task.
    pub retry_attempts: usize,
    /// Strategy used to assign tasks to workers.
    pub load_balancing_strategy: LoadBalancingStrategy,
    /// Whether fault-tolerance handling is enabled.
    pub enable_fault_tolerance: bool,
    /// Interval between worker heartbeats, in seconds.
    pub heartbeat_interval_seconds: u64,
    /// Maximum number of tasks held in the pending queue.
    pub max_queue_size: usize,
    /// Whether automatic scaling of the worker pool is enabled.
    pub enable_auto_scaling: bool,
    /// Target CPU utilization used for auto-scaling decisions.
    pub target_cpu_utilization: f64,
}

impl Default for ClusterConfig {
    fn default() -> Self {
        Self {
            max_workers: 10,
            task_timeout_seconds: 300,
            retry_attempts: 3,
            load_balancing_strategy: LoadBalancingStrategy::RoundRobin,
            enable_fault_tolerance: true,
            heartbeat_interval_seconds: 30,
            max_queue_size: 1000,
            enable_auto_scaling: false,
            target_cpu_utilization: 0.7,
        }
    }
}

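/// Strategies for distributing tasks across worker nodes.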
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Cycle through workers in registration order.
    RoundRobin,
    /// Pick the worker with the fewest tasks in flight.
    LeastLoaded,
    /// Score workers by capacity relative to current load.
    Weighted,
    /// Pick a worker uniformly at random.
    Random,
    /// Prefer the worker with the lowest network latency.
    LocalityAware,
}

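/// Coordinates task distribution across registered worker nodes.
///
/// A minimal usage sketch (module paths elided and not compiled as a doctest;
/// the worker address shown is only illustrative):
///
/// ```ignore
/// let coordinator = DistributedCoordinator::new(ClusterConfig::default())?;
/// coordinator.register_worker("worker1".to_string(), "127.0.0.1:8080".to_string())?;
///
/// let task = DistributedTask {
///     task_id: "task1".to_string(),
///     task_type: TaskType::ComputeShap,
///     priority: 1,
///     input_data: Array2::zeros((10, 5)),
///     metadata: HashMap::new(),
///     created_at: Instant::now(),
/// };
/// coordinator.submit_task(task)?;
/// coordinator.schedule_tasks()?;
/// let result = coordinator.get_result("task1")?;
/// ```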
pub struct DistributedCoordinator {
    /// Cluster configuration.
    config: ClusterConfig,
    /// Registered worker nodes, keyed by worker id.
    workers: Arc<Mutex<HashMap<String, WorkerNode>>>,
    /// Tasks waiting to be scheduled.
    task_queue: Arc<Mutex<VecDeque<DistributedTask>>>,
    /// Completed task results, keyed by task id.
    results: Arc<Mutex<HashMap<String, TaskResult>>>,
    /// Task-id to worker-id assignments.
    assignments: Arc<Mutex<HashMap<String, String>>>,
    /// Counter backing the round-robin strategy.
    round_robin_counter: Arc<Mutex<usize>>,
    /// Aggregate cluster statistics.
    statistics: Arc<Mutex<ClusterStatistics>>,
}

impl DistributedCoordinator {
    /// Creates a new coordinator with the given configuration.
    pub fn new(config: ClusterConfig) -> SklResult<Self> {
        Ok(Self {
            config,
            workers: Arc::new(Mutex::new(HashMap::new())),
            task_queue: Arc::new(Mutex::new(VecDeque::new())),
            results: Arc::new(Mutex::new(HashMap::new())),
            assignments: Arc::new(Mutex::new(HashMap::new())),
            round_robin_counter: Arc::new(Mutex::new(0)),
            statistics: Arc::new(Mutex::new(ClusterStatistics::new())),
        })
    }

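    /// Registers a worker node, failing once `max_workers` is reached.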
    pub fn register_worker(&self, worker_id: String, address: String) -> SklResult<()> {
        let mut workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;

        if workers.len() >= self.config.max_workers {
            return Err(SklearsError::InvalidInput(
                "Maximum number of workers reached".to_string(),
            ));
        }

        let worker = WorkerNode::new(worker_id.clone(), address);
        workers.insert(worker_id.clone(), worker);

        let mut stats = self.statistics.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
        })?;
        stats.active_workers += 1;

        Ok(())
    }

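    /// Removes a previously registered worker from the cluster.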
    pub fn unregister_worker(&self, worker_id: &str) -> SklResult<()> {
        let mut workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;

        if workers.remove(worker_id).is_some() {
            let mut stats = self.statistics.lock().map_err(|_| {
                SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
            })?;
            stats.active_workers = stats.active_workers.saturating_sub(1);
            Ok(())
        } else {
            Err(SklearsError::InvalidInput(format!(
                "Worker {} not found",
                worker_id
            )))
        }
    }

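    /// Enqueues a task for execution, failing if the queue is full.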
    pub fn submit_task(&self, task: DistributedTask) -> SklResult<String> {
        let mut queue = self
            .task_queue
            .lock()
            .map_err(|_| SklearsError::InvalidInput("Failed to acquire queue lock".to_string()))?;

        if queue.len() >= self.config.max_queue_size {
            return Err(SklearsError::InvalidInput("Task queue is full".to_string()));
        }

        let task_id = task.task_id.clone();
        queue.push_back(task);

        let mut stats = self.statistics.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
        })?;
        stats.total_tasks_submitted += 1;
        stats.pending_tasks += 1;

        Ok(task_id)
    }

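    /// Drains the task queue, assigning each pending task to a worker.
    /// Returns the number of tasks scheduled.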
    pub fn schedule_tasks(&self) -> SklResult<usize> {
        let mut scheduled = 0;

        loop {
            let task = {
                let mut queue = self.task_queue.lock().map_err(|_| {
                    SklearsError::InvalidInput("Failed to acquire queue lock".to_string())
                })?;
                queue.pop_front()
            };

            match task {
                None => break,
                Some(task) => {
                    let worker_id = self.select_worker(&task)?;

                    self.assign_task_to_worker(task, &worker_id)?;

                    scheduled += 1;
                }
            }
        }

        Ok(scheduled)
    }

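    /// Picks a worker for the given task according to the configured
    /// load-balancing strategy.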
    fn select_worker(&self, task: &DistributedTask) -> SklResult<String> {
        let workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;

        if workers.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No workers available".to_string(),
            ));
        }

        match self.config.load_balancing_strategy {
            LoadBalancingStrategy::RoundRobin => {
                let mut counter = self.round_robin_counter.lock().map_err(|_| {
                    SklearsError::InvalidInput("Failed to acquire counter lock".to_string())
                })?;

                let worker_ids: Vec<String> = workers.keys().cloned().collect();
                let selected = &worker_ids[*counter % worker_ids.len()];
                *counter += 1;

                Ok(selected.clone())
            }
            LoadBalancingStrategy::LeastLoaded => {
                let mut least_loaded_worker = None;
                let mut min_load = usize::MAX;

                for (worker_id, worker) in workers.iter() {
                    if worker.current_load < min_load {
                        min_load = worker.current_load;
                        least_loaded_worker = Some(worker_id.clone());
                    }
                }

                least_loaded_worker.ok_or_else(|| {
                    SklearsError::InvalidInput("Failed to find least loaded worker".to_string())
                })
            }
            LoadBalancingStrategy::Weighted => {
                let mut best_worker = None;
                let mut best_score = 0.0;

                for (worker_id, worker) in workers.iter() {
                    let score = (worker.capacity as f64) / (worker.current_load as f64 + 1.0);
                    if score > best_score {
                        best_score = score;
                        best_worker = Some(worker_id.clone());
                    }
                }

                best_worker.ok_or_else(|| {
                    SklearsError::InvalidInput("Failed to find weighted worker".to_string())
                })
            }
            LoadBalancingStrategy::Random => {
                use scirs2_core::random::{thread_rng, CoreRandom};

                let worker_ids: Vec<String> = workers.keys().cloned().collect();
                let mut rng = thread_rng();
                let index = rng.gen_range(0..worker_ids.len());
                Ok(worker_ids[index].clone())
            }
            LoadBalancingStrategy::LocalityAware => {
                let mut best_worker = None;
                let mut best_latency = Duration::from_secs(u64::MAX);

                for (worker_id, worker) in workers.iter() {
                    if worker.network_latency < best_latency {
                        best_latency = worker.network_latency;
                        best_worker = Some(worker_id.clone());
                    }
                }

                best_worker.ok_or_else(|| {
                    SklearsError::InvalidInput("Failed to find locality-aware worker".to_string())
                })
            }
        }
    }

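    /// Records the assignment of a task to a worker and updates load and
    /// queue statistics before execution.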
    fn assign_task_to_worker(&self, task: DistributedTask, worker_id: &str) -> SklResult<()> {
        // Update the selected worker's load accounting.
        {
            let mut workers = self.workers.lock().map_err(|_| {
                SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
            })?;

            let worker = workers.get_mut(worker_id).ok_or_else(|| {
                SklearsError::InvalidInput(format!("Worker {} not found", worker_id))
            })?;

            worker.current_load += 1;
            worker.total_tasks_processed += 1;
        }

        // Record the task-to-worker assignment.
        {
            let mut assignments = self.assignments.lock().map_err(|_| {
                SklearsError::InvalidInput("Failed to acquire assignments lock".to_string())
            })?;
            assignments.insert(task.task_id.clone(), worker_id.to_string());
        }

        // Move the task from pending to running in the statistics.
        {
            let mut stats = self.statistics.lock().map_err(|_| {
                SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
            })?;
            stats.pending_tasks = stats.pending_tasks.saturating_sub(1);
            stats.running_tasks += 1;
        }

        self.simulate_task_execution(task, worker_id.to_string())?;

        Ok(())
    }

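    /// Placeholder for task execution: records a zero-valued result immediately
    /// instead of dispatching the task to a remote worker.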
    fn simulate_task_execution(&self, task: DistributedTask, worker_id: String) -> SklResult<()> {
        // Simulated execution: produce a fixed-size zero result with a nominal runtime.
        let result = TaskResult {
            task_id: task.task_id.clone(),
            worker_id: worker_id.clone(),
            status: TaskStatus::Completed,
            result_data: Array1::zeros(10),
            execution_time: Duration::from_millis(100),
            retry_count: 0,
        };

        let mut results = self.results.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire results lock".to_string())
        })?;
        results.insert(task.task_id.clone(), result);

        let mut workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;
        if let Some(worker) = workers.get_mut(&worker_id) {
            worker.current_load = worker.current_load.saturating_sub(1);
        }

        let mut stats = self.statistics.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
        })?;
        stats.running_tasks = stats.running_tasks.saturating_sub(1);
        stats.completed_tasks += 1;

        Ok(())
    }

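    /// Returns the stored result for a task, if it has completed.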
    pub fn get_result(&self, task_id: &str) -> SklResult<TaskResult> {
        let results = self.results.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire results lock".to_string())
        })?;

        results.get(task_id).cloned().ok_or_else(|| {
            SklearsError::InvalidInput(format!("Result for task {} not found", task_id))
        })
    }

    /// Returns a snapshot of the current cluster statistics.
    pub fn get_statistics(&self) -> SklResult<ClusterStatistics> {
        let stats = self.statistics.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
        })?;

        Ok(stats.clone())
    }

    /// Returns information about a single registered worker.
    pub fn get_worker_info(&self, worker_id: &str) -> SklResult<WorkerNode> {
        let workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;

        workers
            .get(worker_id)
            .cloned()
            .ok_or_else(|| SklearsError::InvalidInput(format!("Worker {} not found", worker_id)))
    }

    /// Returns all registered workers.
    pub fn get_all_workers(&self) -> SklResult<Vec<WorkerNode>> {
        let workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;

        Ok(workers.values().cloned().collect())
    }

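    /// Summarizes cluster health from worker status and task statistics.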
    pub fn health_check(&self) -> SklResult<ClusterHealth> {
        let workers = self.workers.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire workers lock".to_string())
        })?;

        let stats = self.statistics.lock().map_err(|_| {
            SklearsError::InvalidInput("Failed to acquire statistics lock".to_string())
        })?;

        let total_workers = workers.len();
        let healthy_workers = workers.values().filter(|w| w.is_healthy).count();

        let health_status = if healthy_workers == 0 {
            HealthStatus::Critical
        } else if healthy_workers < total_workers / 2 {
            HealthStatus::Degraded
        } else if healthy_workers < total_workers {
            HealthStatus::Warning
        } else {
            HealthStatus::Healthy
        };

        Ok(ClusterHealth {
            status: health_status,
            total_workers,
            healthy_workers,
            total_capacity: workers.values().map(|w| w.capacity).sum(),
            current_load: workers.values().map(|w| w.current_load).sum(),
            pending_tasks: stats.pending_tasks,
            running_tasks: stats.running_tasks,
        })
    }
}

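/// State tracked for a single worker node in the cluster.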
#[derive(Debug, Clone)]
pub struct WorkerNode {
    /// Unique identifier of the worker.
    pub worker_id: String,
    /// Network address of the worker.
    pub address: String,
    /// Maximum number of concurrent tasks the worker can run.
    pub capacity: usize,
    /// Number of tasks currently assigned to the worker.
    pub current_load: usize,
    /// Total number of tasks the worker has processed.
    pub total_tasks_processed: usize,
    /// Whether the worker is currently considered healthy.
    pub is_healthy: bool,
    /// Time of the last heartbeat received from the worker.
    pub last_heartbeat: Instant,
    /// Estimated network latency to the worker.
    pub network_latency: Duration,
    /// Most recently reported CPU utilization.
    pub cpu_utilization: f64,
    /// Most recently reported memory utilization.
    pub memory_utilization: f64,
}

impl WorkerNode {
    /// Creates a worker with default capacity and health settings.
    pub fn new(worker_id: String, address: String) -> Self {
        Self {
            worker_id,
            address,
            capacity: 10,
            current_load: 0,
            total_tasks_processed: 0,
            is_healthy: true,
            last_heartbeat: Instant::now(),
            network_latency: Duration::from_millis(10),
            cpu_utilization: 0.0,
            memory_utilization: 0.0,
        }
    }

    /// Records a heartbeat and marks the worker as healthy.
    pub fn heartbeat(&mut self) {
        self.last_heartbeat = Instant::now();
        self.is_healthy = true;
    }

    /// Returns `true` if the worker is at or above its capacity.
    pub fn is_overloaded(&self) -> bool {
        self.current_load >= self.capacity
    }

    /// Returns how many additional tasks the worker can accept.
    pub fn available_capacity(&self) -> usize {
        self.capacity.saturating_sub(self.current_load)
    }
}

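/// A unit of explanation work dispatched to a worker.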
#[derive(Debug, Clone)]
pub struct DistributedTask {
    /// Unique identifier of the task.
    pub task_id: String,
    /// Kind of explanation computation to perform.
    pub task_type: TaskType,
    /// Scheduling priority of the task.
    pub priority: usize,
    /// Input data for the computation.
    pub input_data: Array2<Float>,
    /// Arbitrary key-value metadata attached to the task.
    pub metadata: HashMap<String, String>,
    /// Time at which the task was created.
    pub created_at: Instant,
}

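/// Kinds of explanation computations that can be distributed.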
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskType {
    /// Compute SHAP values.
    ComputeShap,
    /// Compute permutation feature importance.
    ComputePermutationImportance,
    /// Generate counterfactual examples.
    GenerateCounterfactuals,
    /// Compute feature importance scores.
    ComputeFeatureImportance,
    /// Compute explanations for a batch of instances.
    BatchExplanation,
}

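/// Outcome of a distributed task reported back to the coordinator.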
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// Identifier of the task this result belongs to.
    pub task_id: String,
    /// Identifier of the worker that produced the result.
    pub worker_id: String,
    /// Final status of the task.
    pub status: TaskStatus,
    /// Numeric result payload.
    pub result_data: Array1<Float>,
    /// Wall-clock time spent executing the task.
    pub execution_time: Duration,
    /// Number of times the task was retried.
    pub retry_count: usize,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    Pending,
    Running,
    Completed,
    Failed,
    Cancelled,
}

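/// Aggregate counters describing cluster activity.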
#[derive(Debug, Clone)]
pub struct ClusterStatistics {
    /// Number of currently registered workers.
    pub active_workers: usize,
    /// Total number of tasks submitted to the queue.
    pub total_tasks_submitted: usize,
    /// Number of tasks waiting to be scheduled.
    pub pending_tasks: usize,
    /// Number of tasks currently executing.
    pub running_tasks: usize,
    /// Number of tasks that finished successfully.
    pub completed_tasks: usize,
    /// Number of tasks that failed.
    pub failed_tasks: usize,
    /// Average task execution time.
    pub avg_execution_time: Duration,
    /// Total amount of data processed.
    pub total_data_processed: usize,
}

impl ClusterStatistics {
    fn new() -> Self {
        Self {
            active_workers: 0,
            total_tasks_submitted: 0,
            pending_tasks: 0,
            running_tasks: 0,
            completed_tasks: 0,
            failed_tasks: 0,
            avg_execution_time: Duration::from_secs(0),
            total_data_processed: 0,
        }
    }
}

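/// Snapshot of overall cluster health returned by `health_check`.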
#[derive(Debug, Clone)]
pub struct ClusterHealth {
    /// Overall health classification.
    pub status: HealthStatus,
    /// Total number of registered workers.
    pub total_workers: usize,
    /// Number of workers currently reporting as healthy.
    pub healthy_workers: usize,
    /// Combined task capacity across all workers.
    pub total_capacity: usize,
    /// Combined current load across all workers.
    pub current_load: usize,
    /// Number of tasks waiting to be scheduled.
    pub pending_tasks: usize,
    /// Number of tasks currently executing.
    pub running_tasks: usize,
}

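/// Coarse health classification for the cluster.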
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    /// All workers are healthy.
    Healthy,
    /// Some workers are unhealthy.
    Warning,
    /// Fewer than half of the workers are healthy.
    Degraded,
    /// No workers are healthy.
    Critical,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cluster_config_default() {
        let config = ClusterConfig::default();
        assert_eq!(config.max_workers, 10);
        assert_eq!(config.task_timeout_seconds, 300);
        assert_eq!(config.retry_attempts, 3);
        assert_eq!(
            config.load_balancing_strategy,
            LoadBalancingStrategy::RoundRobin
        );
        assert!(config.enable_fault_tolerance);
    }

    #[test]
    fn test_distributed_coordinator_creation() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config);
        assert!(coordinator.is_ok());
    }

    #[test]
    fn test_register_worker() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config).unwrap();

        let result =
            coordinator.register_worker("worker1".to_string(), "192.168.1.10:8080".to_string());
        assert!(result.is_ok());

        let stats = coordinator.get_statistics().unwrap();
        assert_eq!(stats.active_workers, 1);
    }

    #[test]
    fn test_register_multiple_workers() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config).unwrap();

        coordinator
            .register_worker("worker1".to_string(), "192.168.1.10:8080".to_string())
            .unwrap();
        coordinator
            .register_worker("worker2".to_string(), "192.168.1.11:8080".to_string())
            .unwrap();

        let stats = coordinator.get_statistics().unwrap();
        assert_eq!(stats.active_workers, 2);
    }

    #[test]
    fn test_register_worker_limit() {
        let config = ClusterConfig {
            max_workers: 2,
            ..Default::default()
        };
        let coordinator = DistributedCoordinator::new(config).unwrap();

        coordinator
            .register_worker("worker1".to_string(), "192.168.1.10:8080".to_string())
            .unwrap();
        coordinator
            .register_worker("worker2".to_string(), "192.168.1.11:8080".to_string())
            .unwrap();

        let result =
            coordinator.register_worker("worker3".to_string(), "192.168.1.12:8080".to_string());
        assert!(result.is_err());
    }

    #[test]
    fn test_unregister_worker() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config).unwrap();

        coordinator
            .register_worker("worker1".to_string(), "192.168.1.10:8080".to_string())
            .unwrap();

        let result = coordinator.unregister_worker("worker1");
        assert!(result.is_ok());

        let stats = coordinator.get_statistics().unwrap();
        assert_eq!(stats.active_workers, 0);
    }

    #[test]
    fn test_submit_task() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config).unwrap();

        let task = DistributedTask {
            task_id: "task1".to_string(),
            task_type: TaskType::ComputeShap,
            priority: 1,
            input_data: Array2::zeros((10, 5)),
            metadata: HashMap::new(),
            created_at: Instant::now(),
        };

        let result = coordinator.submit_task(task);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "task1");

        let stats = coordinator.get_statistics().unwrap();
        assert_eq!(stats.total_tasks_submitted, 1);
        assert_eq!(stats.pending_tasks, 1);
    }

    #[test]
    fn test_schedule_tasks() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config).unwrap();

        coordinator
            .register_worker("worker1".to_string(), "192.168.1.10:8080".to_string())
            .unwrap();

        let task = DistributedTask {
            task_id: "task1".to_string(),
            task_type: TaskType::ComputeShap,
            priority: 1,
            input_data: Array2::zeros((10, 5)),
            metadata: HashMap::new(),
            created_at: Instant::now(),
        };
        coordinator.submit_task(task).unwrap();

        let scheduled = coordinator.schedule_tasks().unwrap();
        assert_eq!(scheduled, 1);
    }

    #[test]
    fn test_worker_node_creation() {
        let worker = WorkerNode::new("worker1".to_string(), "192.168.1.10:8080".to_string());
        assert_eq!(worker.worker_id, "worker1");
        assert_eq!(worker.address, "192.168.1.10:8080");
        assert_eq!(worker.capacity, 10);
        assert_eq!(worker.current_load, 0);
        assert!(worker.is_healthy);
    }

    #[test]
    fn test_worker_node_overload() {
        let mut worker = WorkerNode::new("worker1".to_string(), "192.168.1.10:8080".to_string());
        worker.capacity = 5;
        worker.current_load = 3;

        assert!(!worker.is_overloaded());

        worker.current_load = 5;
        assert!(worker.is_overloaded());

        worker.current_load = 6;
        assert!(worker.is_overloaded());
    }

    #[test]
    fn test_worker_node_available_capacity() {
        let mut worker = WorkerNode::new("worker1".to_string(), "192.168.1.10:8080".to_string());
        worker.capacity = 10;
        worker.current_load = 3;

        assert_eq!(worker.available_capacity(), 7);
    }

    #[test]
    fn test_cluster_health_check() {
        let config = ClusterConfig::default();
        let coordinator = DistributedCoordinator::new(config).unwrap();

        coordinator
            .register_worker("worker1".to_string(), "192.168.1.10:8080".to_string())
            .unwrap();
        coordinator
            .register_worker("worker2".to_string(), "192.168.1.11:8080".to_string())
            .unwrap();

        let health = coordinator.health_check().unwrap();
        assert_eq!(health.status, HealthStatus::Healthy);
        assert_eq!(health.total_workers, 2);
        assert_eq!(health.healthy_workers, 2);
    }

    #[test]
    fn test_load_balancing_strategies() {
        assert_ne!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::LeastLoaded
        );
        assert_ne!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::Weighted
        );
        assert_ne!(
            LoadBalancingStrategy::LeastLoaded,
            LoadBalancingStrategy::Weighted
        );
        assert_ne!(
            LoadBalancingStrategy::Random,
            LoadBalancingStrategy::LocalityAware
        );

        assert_eq!(
            LoadBalancingStrategy::RoundRobin,
            LoadBalancingStrategy::RoundRobin
        );
    }
}

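/// High-level orchestrator that batches explanation workloads and runs them
/// through a [`DistributedCoordinator`].
///
/// A minimal usage sketch (module paths elided and not compiled as a doctest;
/// the worker address is only illustrative):
///
/// ```ignore
/// let orchestrator = ClusterExplanationOrchestrator::new(ClusterConfig::default())?;
/// orchestrator.register_workers_from_config(vec![WorkerConfig {
///     worker_id: "worker1".to_string(),
///     address: "127.0.0.1:8080".to_string(),
///     capacity: 10,
/// }])?;
///
/// let data = Array2::<Float>::zeros((100, 8));
/// let background = Array2::<Float>::zeros((20, 8));
/// let shap_values = orchestrator.compute_shap_distributed(&data, &background, 25)?;
/// ```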
pub struct ClusterExplanationOrchestrator {
    /// Coordinator used to schedule distributed tasks.
    coordinator: Arc<DistributedCoordinator>,
    /// Cluster configuration.
    config: ClusterConfig,
    /// Cache of previously computed explanations.
    cache: Arc<Mutex<HashMap<String, CachedExplanation>>>,
    /// Batch computations currently in flight.
    active_batches: Arc<Mutex<HashMap<String, BatchComputation>>>,
}

impl ClusterExplanationOrchestrator {
    /// Creates an orchestrator backed by a new coordinator.
    pub fn new(config: ClusterConfig) -> SklResult<Self> {
        let coordinator = Arc::new(DistributedCoordinator::new(config.clone())?);

        Ok(Self {
            coordinator,
            config,
            cache: Arc::new(Mutex::new(HashMap::new())),
            active_batches: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Registers all workers described by the given configurations.
    pub fn register_workers_from_config(&self, worker_configs: Vec<WorkerConfig>) -> SklResult<()> {
        for worker_config in worker_configs {
            self.coordinator
                .register_worker(worker_config.worker_id, worker_config.address)?;
        }
        Ok(())
    }

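    /// Computes SHAP values by splitting `data` into row batches, submitting one
    /// task per batch, and stitching the per-batch results back into a single
    /// `(n_samples, n_features)` array.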
    pub fn compute_shap_distributed(
        &self,
        data: &Array2<Float>,
        background_data: &Array2<Float>,
        batch_size: usize,
    ) -> SklResult<Array2<Float>> {
        let n_samples = data.nrows();
        let n_features = data.ncols();

        let batch_id = format!("shap_batch_{}", uuid::Uuid::new_v4());

        // Split the input into row batches of at most `batch_size` samples.
        let batches = self.split_into_batches(data, batch_size)?;

        // Submit one task per batch.
        let mut task_ids = Vec::new();
        for (batch_idx, batch) in batches.iter().enumerate() {
            let task = DistributedTask {
                task_id: format!("{}_task_{}", batch_id, batch_idx),
                task_type: TaskType::ComputeShap,
                priority: 1,
                input_data: batch.clone(),
                metadata: {
                    let mut meta = HashMap::new();
                    meta.insert("batch_id".to_string(), batch_id.clone());
                    meta.insert("batch_idx".to_string(), batch_idx.to_string());
                    meta
                },
                created_at: Instant::now(),
            };

            let task_id = self.coordinator.submit_task(task)?;
            task_ids.push(task_id);
        }

        self.coordinator.schedule_tasks()?;

        // Collect the per-batch results in submission order.
        let mut all_results = Vec::new();
        for task_id in task_ids {
            let result = self.coordinator.get_result(&task_id)?;
            all_results.push(result.result_data);
        }

        // Reassemble the per-batch outputs into a single (n_samples, n_features) matrix.
        let aggregated = self.aggregate_shap_results(all_results, n_samples, n_features)?;

        Ok(aggregated)
    }

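    /// Computes feature importance scores by batching the data, running one
    /// task per batch, and averaging the per-batch importance vectors.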
    pub fn compute_feature_importance_distributed(
        &self,
        data: &Array2<Float>,
        predictions: &Array1<Float>,
        batch_size: usize,
    ) -> SklResult<Array1<Float>> {
        let n_samples = data.nrows();
        let n_features = data.ncols();

        let batch_id = format!("importance_batch_{}", uuid::Uuid::new_v4());

        // Split data and predictions into matching row batches.
        let data_batches = self.split_into_batches(data, batch_size)?;
        let pred_batches = self.split_predictions(predictions, batch_size)?;

        let mut task_ids = Vec::new();
        for (batch_idx, (data_batch, pred_batch)) in
            data_batches.iter().zip(pred_batches.iter()).enumerate()
        {
            let task = DistributedTask {
                task_id: format!("{}_task_{}", batch_id, batch_idx),
                task_type: TaskType::ComputeFeatureImportance,
                priority: 1,
                input_data: data_batch.clone(),
                metadata: {
                    let mut meta = HashMap::new();
                    meta.insert("batch_id".to_string(), batch_id.clone());
                    meta.insert("batch_idx".to_string(), batch_idx.to_string());
                    meta
                },
                created_at: Instant::now(),
            };

            let task_id = self.coordinator.submit_task(task)?;
            task_ids.push(task_id);
        }

        self.coordinator.schedule_tasks()?;

        // Average the per-batch importance vectors.
        let mut importance_sum = Array1::zeros(n_features);
        let mut count = 0;

        for task_id in task_ids {
            let result = self.coordinator.get_result(&task_id)?;
            importance_sum += &result.result_data.slice(s![..n_features]).to_owned();
            count += 1;
        }

        Ok(importance_sum / (count as Float))
    }

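    /// Generates counterfactuals by submitting one task per input instance and
    /// collecting the per-instance result vectors.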
    pub fn generate_counterfactuals_distributed(
        &self,
        instances: &Array2<Float>,
        target_class: usize,
        n_counterfactuals_per_instance: usize,
    ) -> SklResult<Vec<Array1<Float>>> {
        let batch_id = format!("counterfactual_batch_{}", uuid::Uuid::new_v4());

        // One task per instance, carrying the target class in its metadata.
        let mut task_ids = Vec::new();
        for (instance_idx, instance) in instances.axis_iter(Axis(0)).enumerate() {
            let task = DistributedTask {
                task_id: format!("{}_task_{}", batch_id, instance_idx),
                task_type: TaskType::GenerateCounterfactuals,
                priority: 2,
                input_data: instance.to_owned().insert_axis(Axis(0)),
                metadata: {
                    let mut meta = HashMap::new();
                    meta.insert("batch_id".to_string(), batch_id.clone());
                    meta.insert("instance_idx".to_string(), instance_idx.to_string());
                    meta.insert("target_class".to_string(), target_class.to_string());
                    meta
                },
                created_at: Instant::now(),
            };

            let task_id = self.coordinator.submit_task(task)?;
            task_ids.push(task_id);
        }

        self.coordinator.schedule_tasks()?;

        let mut all_counterfactuals = Vec::new();
        for task_id in task_ids {
            let result = self.coordinator.get_result(&task_id)?;
            all_counterfactuals.push(result.result_data);
        }

        Ok(all_counterfactuals)
    }

    /// Splits `data` into row batches of at most `batch_size` samples.
    fn split_into_batches(
        &self,
        data: &Array2<Float>,
        batch_size: usize,
    ) -> SklResult<Vec<Array2<Float>>> {
        let n_samples = data.nrows();
        let mut batches = Vec::new();

        for start_idx in (0..n_samples).step_by(batch_size) {
            let end_idx = (start_idx + batch_size).min(n_samples);
            let batch = data.slice(s![start_idx..end_idx, ..]).to_owned();
            batches.push(batch);
        }

        Ok(batches)
    }

    /// Splits `predictions` into chunks of at most `batch_size` values.
    fn split_predictions(
        &self,
        predictions: &Array1<Float>,
        batch_size: usize,
    ) -> SklResult<Vec<Array1<Float>>> {
        let n_samples = predictions.len();
        let mut batches = Vec::new();

        for start_idx in (0..n_samples).step_by(batch_size) {
            let end_idx = (start_idx + batch_size).min(n_samples);
            let batch = predictions.slice(s![start_idx..end_idx]).to_owned();
            batches.push(batch);
        }

        Ok(batches)
    }

    /// Reassembles flattened per-batch result vectors into a single
    /// `(n_samples, n_features)` matrix, filling rows in batch order.
    fn aggregate_shap_results(
        &self,
        results: Vec<Array1<Float>>,
        n_samples: usize,
        n_features: usize,
    ) -> SklResult<Array2<Float>> {
        let mut aggregated = Array2::zeros((n_samples, n_features));

        let mut sample_idx = 0;
        for result in results {
            // Each result vector holds `batch_size` rows laid out contiguously.
            let batch_size = result.len() / n_features;

            for i in 0..batch_size {
                if sample_idx < n_samples {
                    for j in 0..n_features {
                        let result_idx = i * n_features + j;
                        if result_idx < result.len() {
                            aggregated[[sample_idx, j]] = result[result_idx];
                        }
                    }
                    sample_idx += 1;
                }
            }
        }

        Ok(aggregated)
    }

    /// Returns the underlying coordinator's statistics.
    pub fn get_cluster_statistics(&self) -> SklResult<ClusterStatistics> {
        self.coordinator.get_statistics()
    }

    /// Returns the underlying coordinator's health summary.
    pub fn get_cluster_health(&self) -> SklResult<ClusterHealth> {
        self.coordinator.health_check()
    }

    /// Adds new workers to the cluster.
    pub fn scale_up(&self, new_workers: Vec<WorkerConfig>) -> SklResult<()> {
        self.register_workers_from_config(new_workers)
    }

    /// Removes the given workers from the cluster.
    pub fn scale_down(&self, worker_ids: Vec<String>) -> SklResult<()> {
        for worker_id in worker_ids {
            self.coordinator.unregister_worker(&worker_id)?;
        }
        Ok(())
    }
}

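/// Static description of a worker used when registering it with the cluster.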
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Unique identifier of the worker.
    pub worker_id: String,
    /// Network address of the worker.
    pub address: String,
    /// Maximum number of concurrent tasks the worker can run.
    pub capacity: usize,
}

#[derive(Debug, Clone)]
struct CachedExplanation {
    /// Cached explanation values.
    data: Array1<Float>,
    /// Time at which the explanation was cached.
    cached_at: Instant,
    /// Number of cache hits for this entry.
    hit_count: usize,
}

#[derive(Debug, Clone)]
struct BatchComputation {
    /// Identifier shared by all tasks in the batch.
    batch_id: String,
    /// Tasks belonging to the batch.
    task_ids: Vec<String>,
    /// Time at which the batch was started.
    started_at: Instant,
    /// Whether all tasks in the batch have completed.
    is_complete: bool,
}

use scirs2_core::ndarray::{s, Axis};

#[cfg(test)]
mod cluster_tests {
    use super::*;

    #[test]
    fn test_cluster_orchestrator_creation() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config);
        assert!(orchestrator.is_ok());
    }

    #[test]
    fn test_register_workers_from_config() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config).unwrap();

        let worker_configs = vec![
            WorkerConfig {
                worker_id: "worker1".to_string(),
                address: "192.168.1.10:8080".to_string(),
                capacity: 10,
            },
            WorkerConfig {
                worker_id: "worker2".to_string(),
                address: "192.168.1.11:8080".to_string(),
                capacity: 10,
            },
        ];

        let result = orchestrator.register_workers_from_config(worker_configs);
        assert!(result.is_ok());

        let stats = orchestrator.get_cluster_statistics().unwrap();
        assert_eq!(stats.active_workers, 2);
    }

    #[test]
    fn test_split_into_batches() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config).unwrap();

        let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as Float).collect()).unwrap();
        let batches = orchestrator.split_into_batches(&data, 3).unwrap();

        // 10 rows with a batch size of 3 give three full batches and one partial batch.
        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0].nrows(), 3);
        assert_eq!(batches[1].nrows(), 3);
        assert_eq!(batches[2].nrows(), 3);
        assert_eq!(batches[3].nrows(), 1);
    }

    #[test]
    fn test_split_predictions() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config).unwrap();

        let predictions = Array1::from_vec((0..10).map(|x| x as Float).collect());
        let batches = orchestrator.split_predictions(&predictions, 4).unwrap();

        // 10 predictions with a batch size of 4 give two full chunks and one partial chunk.
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[1].len(), 4);
        assert_eq!(batches[2].len(), 2);
    }

    #[test]
    fn test_cluster_health() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config).unwrap();

        let worker_configs = vec![WorkerConfig {
            worker_id: "worker1".to_string(),
            address: "192.168.1.10:8080".to_string(),
            capacity: 10,
        }];
        orchestrator
            .register_workers_from_config(worker_configs)
            .unwrap();

        let health = orchestrator.get_cluster_health().unwrap();
        assert_eq!(health.status, HealthStatus::Healthy);
        assert_eq!(health.total_workers, 1);
        assert_eq!(health.healthy_workers, 1);
    }

    #[test]
    fn test_scale_up() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config).unwrap();

        let initial_workers = vec![WorkerConfig {
            worker_id: "worker1".to_string(),
            address: "192.168.1.10:8080".to_string(),
            capacity: 10,
        }];
        orchestrator
            .register_workers_from_config(initial_workers)
            .unwrap();

        let new_workers = vec![
            WorkerConfig {
                worker_id: "worker2".to_string(),
                address: "192.168.1.11:8080".to_string(),
                capacity: 10,
            },
            WorkerConfig {
                worker_id: "worker3".to_string(),
                address: "192.168.1.12:8080".to_string(),
                capacity: 10,
            },
        ];
        let result = orchestrator.scale_up(new_workers);
        assert!(result.is_ok());

        let stats = orchestrator.get_cluster_statistics().unwrap();
        assert_eq!(stats.active_workers, 3);
    }

    #[test]
    fn test_scale_down() {
        let config = ClusterConfig::default();
        let orchestrator = ClusterExplanationOrchestrator::new(config).unwrap();

        let worker_configs = vec![
            WorkerConfig {
                worker_id: "worker1".to_string(),
                address: "192.168.1.10:8080".to_string(),
                capacity: 10,
            },
            WorkerConfig {
                worker_id: "worker2".to_string(),
                address: "192.168.1.11:8080".to_string(),
                capacity: 10,
            },
            WorkerConfig {
                worker_id: "worker3".to_string(),
                address: "192.168.1.12:8080".to_string(),
                capacity: 10,
            },
        ];
        orchestrator
            .register_workers_from_config(worker_configs)
            .unwrap();

        let result = orchestrator.scale_down(vec!["worker3".to_string()]);
        assert!(result.is_ok());

        let stats = orchestrator.get_cluster_statistics().unwrap();
        assert_eq!(stats.active_workers, 2);
    }

    #[test]
    fn test_worker_config_creation() {
        let config = WorkerConfig {
            worker_id: "test_worker".to_string(),
            address: "localhost:8080".to_string(),
            capacity: 20,
        };

        assert_eq!(config.worker_id, "test_worker");
        assert_eq!(config.address, "localhost:8080");
        assert_eq!(config.capacity, 20);
    }
}