1use crate::error::{SpatialError, SpatialResult};
60use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
61use scirs2_core::random::quick::random_f64;
62use std::collections::{BTreeMap, HashMap, VecDeque};
63use std::sync::Arc;
64use std::time::{Duration, Instant};
65#[cfg(feature = "async")]
66use tokio::sync::{mpsc, RwLock as TokioRwLock};
67
/// Configuration options for a distributed spatial cluster.
#[derive(Debug, Clone)]
pub struct NodeConfig {
    /// Number of nodes in the cluster.
    pub node_count: usize,
    /// Whether partitions are replicated for fault tolerance.
    pub fault_tolerance: bool,
    /// Whether dynamic load balancing is enabled.
    pub load_balancing: bool,
    /// Whether inter-node message compression is enabled.
    pub compression: bool,
    /// Timeout for inter-node communication, in milliseconds.
    pub communication_timeout_ms: u64,
    /// Interval between heartbeat messages, in milliseconds.
    pub heartbeat_interval_ms: u64,
    /// Maximum number of retries for failed operations.
    pub max_retries: usize,
    /// Number of copies of each partition (1 = no replication).
    pub replication_factor: usize,
}
88
impl Default for NodeConfig {
    /// Delegates to [`NodeConfig::new`]: single node, all optional
    /// features disabled.
    fn default() -> Self {
        Self::new()
    }
}
94
95impl NodeConfig {
96 pub fn new() -> Self {
98 Self {
99 node_count: 1,
100 fault_tolerance: false,
101 load_balancing: false,
102 compression: false,
103 communication_timeout_ms: 5000,
104 heartbeat_interval_ms: 1000,
105 max_retries: 3,
106 replication_factor: 1,
107 }
108 }
109
110 pub fn with_node_count(mut self, count: usize) -> Self {
112 self.node_count = count;
113 self
114 }
115
116 pub fn with_fault_tolerance(mut self, enabled: bool) -> Self {
118 self.fault_tolerance = enabled;
119 if enabled && self.replication_factor < 2 {
120 self.replication_factor = 2;
121 }
122 self
123 }
124
125 pub fn with_load_balancing(mut self, enabled: bool) -> Self {
127 self.load_balancing = enabled;
128 self
129 }
130
131 pub fn with_compression(mut self, enabled: bool) -> Self {
133 self.compression = enabled;
134 self
135 }
136}
137
/// A distributed spatial compute cluster: a set of nodes holding
/// spatially partitioned data behind async (`tokio`) locks.
///
/// NOTE(review): the `tokio` import is gated behind `#[cfg(feature =
/// "async")]` but these fields use `TokioRwLock`/`mpsc` unconditionally —
/// confirm the whole module is compiled only with that feature enabled.
#[derive(Debug)]
pub struct DistributedSpatialCluster {
    /// Cluster-wide configuration.
    config: NodeConfig,
    /// One handle per node; each node owns its local data and index.
    nodes: Vec<Arc<TokioRwLock<NodeInstance>>>,
    /// Id of the coordinating node (always 0 today; not yet consulted).
    #[allow(dead_code)]
    master_node_id: usize,
    /// Partition id -> partition metadata and data.
    partitions: Arc<TokioRwLock<HashMap<usize, DataPartition>>>,
    /// Load-balancing state (stored, not yet acted on in this impl).
    load_balancer: Arc<TokioRwLock<LoadBalancer>>,
    /// Fault-detection state (stored, not yet acted on in this impl).
    #[allow(dead_code)]
    fault_detector: Arc<TokioRwLock<FaultDetector>>,
    /// Inter-node message channels and traffic statistics.
    communication: Arc<TokioRwLock<CommunicationLayer>>,
    /// Aggregated cluster counters and performance state.
    cluster_state: Arc<TokioRwLock<ClusterState>>,
}
160
/// Runtime state of a single cluster node.
#[derive(Debug)]
pub struct NodeInstance {
    /// Unique id of this node within the cluster.
    pub node_id: usize,
    /// Current lifecycle status.
    pub status: NodeStatus,
    /// Rows held locally (primary plus replicated partition data).
    pub local_data: Option<Array2<f64>>,
    /// Spatial index over `local_data`, once built.
    pub local_index: Option<DistributedSpatialIndex>,
    /// Most recent load measurements for this node.
    pub load_metrics: LoadMetrics,
    /// Time of the last heartbeat observed from this node.
    pub last_heartbeat: Instant,
    /// Ids of the partitions stored on this node.
    pub assigned_partitions: Vec<usize>,
}
179
/// Lifecycle status of a cluster node.
#[derive(Debug, Clone, PartialEq)]
pub enum NodeStatus {
    /// Healthy and serving requests.
    Active,
    /// Known to the cluster but not currently serving.
    Inactive,
    /// Declared failed.
    Failed,
    /// Recovering after a failure.
    Recovering,
    /// Joining the cluster.
    Joining,
    /// Leaving the cluster.
    Leaving,
}
190
/// A spatially contiguous chunk of the dataset plus its placement info.
#[derive(Debug, Clone)]
pub struct DataPartition {
    /// Unique partition id.
    pub partition_id: usize,
    /// Axis-aligned bounding box of the partition's points.
    pub bounds: SpatialBounds,
    /// The points themselves, one row per point.
    pub data: Array2<f64>,
    /// Node holding the authoritative copy.
    pub primary_node: usize,
    /// Nodes holding replicas (empty when fault tolerance is off).
    pub replica_nodes: Vec<usize>,
    /// Number of points in the partition.
    pub size: usize,
    /// Time of the last modification.
    pub last_modified: Instant,
}
209
/// Axis-aligned bounding box in n-dimensional space.
#[derive(Debug, Clone)]
pub struct SpatialBounds {
    /// Per-dimension lower corner.
    pub min_coords: Array1<f64>,
    /// Per-dimension upper corner.
    pub max_coords: Array1<f64>,
}
218
219impl SpatialBounds {
220 pub fn contains(&self, point: &ArrayView1<f64>) -> bool {
222 point
223 .iter()
224 .zip(self.min_coords.iter())
225 .zip(self.max_coords.iter())
226 .all(|((&coord, &min_coord), &max_coord)| coord >= min_coord && coord <= max_coord)
227 }
228
229 pub fn volume(&self) -> f64 {
231 self.min_coords
232 .iter()
233 .zip(self.max_coords.iter())
234 .map(|(&min_coord, &max_coord)| max_coord - min_coord)
235 .product()
236 }
237}
238
/// State for partition load balancing (currently stored but not acted on
/// by any method in this file).
#[derive(Debug)]
pub struct LoadBalancer {
    /// Last known load per node id.
    #[allow(dead_code)]
    node_loads: HashMap<usize, LoadMetrics>,
    /// Strategy used to place/migrate partitions.
    #[allow(dead_code)]
    strategy: LoadBalancingStrategy,
    /// Time of the last rebalance pass.
    #[allow(dead_code)]
    last_rebalance: Instant,
    /// Load score above which a node counts as overloaded.
    #[allow(dead_code)]
    load_threshold: f64,
}
255
/// Strategy for distributing partitions across nodes.
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    /// Cycle through nodes in order.
    RoundRobin,
    /// Prefer the node with the lowest load score.
    LeastLoaded,
    /// Assign in proportion to load.
    ProportionalLoad,
    /// Adapt placement to observed load.
    AdaptiveLoad,
}
264
/// Point-in-time load measurements for a node.
///
/// The utilization fields are presumably fractions in `[0, 1]` — the
/// weighting in `load_score` assumes this; confirm at the call sites
/// that populate them.
#[derive(Debug, Clone)]
pub struct LoadMetrics {
    /// CPU utilization.
    pub cpu_utilization: f64,
    /// Memory utilization.
    pub memory_utilization: f64,
    /// Network utilization.
    pub network_utilization: f64,
    /// Number of partitions hosted on the node.
    pub partition_count: usize,
    /// Number of operations served.
    pub operation_count: usize,
    /// Time these metrics were last refreshed.
    pub last_update: Instant,
}
281
282impl LoadMetrics {
283 pub fn load_score(&self) -> f64 {
285 0.4 * self.cpu_utilization
286 + 0.3 * self.memory_utilization
287 + 0.2 * self.network_utilization
288 + 0.1 * (self.partition_count as f64 / 10.0).min(1.0)
289 }
290}
291
/// State for detecting and reacting to node failures (currently stored
/// but not acted on by any method in this file).
#[derive(Debug)]
pub struct FaultDetector {
    /// Health tracking per node id.
    #[allow(dead_code)]
    node_health: HashMap<usize, NodeHealth>,
    /// Silence duration after which a node is considered failed.
    #[allow(dead_code)]
    failure_threshold: Duration,
    /// Recovery strategy to apply per failure type.
    #[allow(dead_code)]
    recovery_strategies: HashMap<FailureType, RecoveryStrategy>,
}
305
/// Health-tracking record for a single node.
#[derive(Debug, Clone)]
pub struct NodeHealth {
    /// Time of the last successful contact.
    pub last_contact: Instant,
    /// Number of consecutive failed contacts.
    pub consecutive_failures: usize,
    /// Recent response-time samples.
    pub response_times: VecDeque<Duration>,
    /// Aggregated health score.
    pub health_score: f64,
}
318
/// Kinds of failures the fault detector distinguishes.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum FailureType {
    /// Node stopped answering entirely.
    NodeUnresponsive,
    /// Node answers, but too slowly.
    HighLatency,
    /// Node ran out of CPU/memory/disk capacity.
    ResourceExhaustion,
    /// Part of the node's functionality failed.
    PartialFailure,
    /// Network split the cluster.
    NetworkPartition,
}
328
/// Actions available for recovering from a detected failure.
#[derive(Debug, Clone)]
pub enum RecoveryStrategy {
    /// Restart the affected node.
    Restart,
    /// Move its partitions to another node.
    Migrate,
    /// Create additional replicas.
    Replicate,
    /// Fence the node off from the cluster.
    Isolate,
    /// Back off and retry later.
    WaitAndRetry,
}
338
/// Inter-node messaging: per-node send channels plus traffic statistics.
#[derive(Debug)]
pub struct CommunicationLayer {
    /// Sender half of each node's message channel, keyed by node id.
    #[allow(dead_code)]
    channels: HashMap<usize, mpsc::Sender<DistributedMessage>>,
    /// Whether payloads are compressed before sending.
    #[allow(dead_code)]
    compression_enabled: bool,
    /// Cumulative traffic counters.
    stats: CommunicationStats,
}
351
/// Cumulative inter-node traffic counters.
#[derive(Debug, Clone)]
pub struct CommunicationStats {
    /// Total messages sent.
    pub messages_sent: u64,
    /// Total messages received.
    pub messages_received: u64,
    /// Total payload bytes sent.
    pub bytes_sent: u64,
    /// Total payload bytes received.
    pub bytes_received: u64,
    /// Rolling average message latency, in milliseconds.
    pub average_latency_ms: f64,
}
366
/// Messages exchanged between cluster nodes.
#[derive(Debug, Clone)]
pub enum DistributedMessage {
    /// Periodic liveness signal carrying current load.
    Heartbeat {
        node_id: usize,
        timestamp: Instant,
        load_metrics: LoadMetrics,
    },
    /// Transfer of a data partition to a node.
    DataDistribution {
        partition_id: usize,
        data: Array2<f64>,
        bounds: SpatialBounds,
    },
    /// A spatial query to execute on the receiving node.
    Query {
        query_id: usize,
        query_type: QueryType,
        parameters: QueryParameters,
    },
    /// Results of a previously issued query.
    QueryResponse {
        query_id: usize,
        results: QueryResults,
        node_id: usize,
    },
    /// Instruction to execute a rebalancing plan.
    LoadBalance { rebalance_plan: RebalancePlan },
    /// Notification of a failure plus the planned recovery.
    FaultTolerance {
        failure_type: FailureType,
        affected_nodes: Vec<usize>,
        recovery_plan: RecoveryPlan,
    },
}
403
/// Kinds of spatial queries the cluster supports.
#[derive(Debug, Clone)]
pub enum QueryType {
    /// k-nearest-neighbor search.
    KNearestNeighbors,
    /// All points within a radius.
    RangeSearch,
    /// Clustering (e.g. k-means).
    Clustering,
    /// Pairwise distance matrix.
    DistanceMatrix,
}
412
/// Parameters for a spatial query; only the fields relevant to the
/// query's [`QueryType`] are expected to be set.
#[derive(Debug, Clone)]
pub struct QueryParameters {
    /// Query location (for kNN / range search).
    pub query_point: Option<Array1<f64>>,
    /// Search radius (for range search).
    pub radius: Option<f64>,
    /// Neighbor count (for kNN).
    pub k: Option<usize>,
    /// Cluster count (for clustering).
    pub num_clusters: Option<usize>,
    /// Free-form extra parameters, keyed by name.
    pub extra_params: HashMap<String, f64>,
}
427
/// Results of a spatial query, one variant per [`QueryType`].
#[derive(Debug, Clone)]
pub enum QueryResults {
    /// kNN result: matched point indices and their distances.
    NearestNeighbors {
        indices: Vec<usize>,
        distances: Vec<f64>,
    },
    /// Range-search result: matched indices and the points themselves.
    RangeSearch {
        indices: Vec<usize>,
        points: Array2<f64>,
    },
    /// Clustering result: centroids and per-point labels.
    Clustering {
        centroids: Array2<f64>,
        assignments: Array1<usize>,
    },
    /// Full pairwise distance matrix.
    DistanceMatrix {
        matrix: Array2<f64>,
    },
}
447
/// A planned set of partition migrations plus its expected payoff.
#[derive(Debug, Clone)]
pub struct RebalancePlan {
    /// Migrations to execute.
    pub migrations: Vec<PartitionMigration>,
    /// Expected improvement in load balance.
    pub load_improvement: f64,
    /// Estimated cost of performing the migrations.
    pub migration_cost: f64,
}
458
/// A single partition move between two nodes.
#[derive(Debug, Clone)]
pub struct PartitionMigration {
    /// Partition to move.
    pub partition_id: usize,
    /// Current owner.
    pub from_node: usize,
    /// Destination node.
    pub to_node: usize,
    /// Relative urgency of this migration.
    pub priority: f64,
}
471
/// A planned sequence of recovery actions for a failure.
#[derive(Debug, Clone)]
pub struct RecoveryPlan {
    /// Actions to execute, in order.
    pub actions: Vec<RecoveryAction>,
    /// Estimated time until the cluster is healthy again.
    pub estimated_recovery_time: Duration,
    /// Estimated probability the plan succeeds.
    pub success_probability: f64,
}
482
/// One step of a [`RecoveryPlan`].
#[derive(Debug, Clone)]
pub struct RecoveryAction {
    /// Strategy to apply.
    pub action_type: RecoveryStrategy,
    /// Node the action targets.
    pub target_node: usize,
    /// Strategy-specific parameters, keyed by name.
    pub parameters: HashMap<String, String>,
}
493
/// Aggregated, cluster-wide bookkeeping.
#[derive(Debug)]
pub struct ClusterState {
    /// Ids of nodes currently active.
    pub active_nodes: Vec<usize>,
    /// Total points distributed across the cluster.
    pub total_data_points: usize,
    /// Total partitions across the cluster.
    pub total_partitions: usize,
    /// Overall cluster health score.
    pub health_score: f64,
    /// Aggregated performance metrics.
    pub performance_metrics: ClusterPerformanceMetrics,
}
508
/// Cluster-wide performance summary.
#[derive(Debug, Clone)]
pub struct ClusterPerformanceMetrics {
    /// Average query latency, in milliseconds.
    pub avg_query_latency_ms: f64,
    /// Throughput in queries per second.
    pub throughput_qps: f64,
    /// How evenly load is spread across nodes.
    pub load_balance_score: f64,
    /// Degree of fault tolerance currently provided.
    pub fault_tolerance_level: f64,
}
521
/// A node's view of the distributed index: its local index plus the
/// metadata needed to route queries to other nodes.
#[derive(Debug)]
pub struct DistributedSpatialIndex {
    /// Index over this node's local data.
    pub local_index: LocalSpatialIndex,
    /// Cluster-wide bounds and partition layout.
    pub global_metadata: GlobalIndexMetadata,
    /// Spatial-key routing table for cross-node queries.
    pub routing_table: RoutingTable,
}
532
/// KD-tree index over a node's local data.
#[derive(Debug)]
pub struct LocalSpatialIndex {
    /// The KD-tree itself (`None` until built).
    pub kdtree: Option<crate::KDTree<f64, crate::EuclideanDistance<f64>>>,
    /// Bounding box of the indexed points.
    pub bounds: SpatialBounds,
    /// Build/usage statistics for this index.
    pub stats: IndexStatistics,
}
543
/// Cluster-wide index metadata shared across nodes.
#[derive(Debug, Clone)]
pub struct GlobalIndexMetadata {
    /// Bounding box of the full distributed dataset.
    pub global_bounds: SpatialBounds,
    /// Partition id -> bounding box of that partition.
    pub partition_map: HashMap<usize, SpatialBounds>,
    /// Version counter for invalidating stale metadata.
    pub version: usize,
}
554
/// Maps spatial keys to the nodes responsible for them.
#[derive(Debug)]
pub struct RoutingTable {
    /// Ordered key -> node-ids mapping (supports range scans).
    pub entries: BTreeMap<SpatialKey, Vec<usize>>,
    /// Lookup cache for recently routed keys.
    pub cache: HashMap<SpatialKey, Vec<usize>>,
}
563
/// Routing key: a Z-order (Morton) code at a given refinement level.
/// Total ordering follows field order: `z_order` first, then `level`.
#[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq)]
pub struct SpatialKey {
    /// Morton code of the region.
    pub z_order: u64,
    /// Refinement level of the code.
    pub level: usize,
}
572
/// Build and usage statistics for a local spatial index.
#[derive(Debug, Clone)]
pub struct IndexStatistics {
    /// Time spent building the index, in milliseconds.
    pub build_time_ms: f64,
    /// Estimated memory footprint of the index, in bytes.
    pub memory_usage_bytes: usize,
    /// Number of queries served.
    pub query_count: u64,
    /// Average query time, in milliseconds.
    pub avg_query_time_ms: f64,
}
585
586impl DistributedSpatialCluster {
    /// Builds a cluster with `config.node_count` freshly initialized nodes.
    ///
    /// Every node starts `Active` with no data, zeroed load metrics, and a
    /// current heartbeat timestamp; the load balancer, fault detector,
    /// communication layer, and cluster state are created with empty or
    /// initial values.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; the `Result` leaves room for future
    /// validation of `config`.
    pub fn new(config: NodeConfig) -> SpatialResult<Self> {
        let mut nodes = Vec::new();
        let mut channels = HashMap::new();

        for node_id in 0..config.node_count {
            // NOTE(review): the receiver half of this channel is dropped at
            // the end of the loop iteration, so later sends on `sender`
            // will fail unless delivery is wired up elsewhere — confirm.
            let (sender, receiver) = mpsc::channel(1000);
            channels.insert(node_id, sender);

            let node = NodeInstance {
                node_id,
                status: NodeStatus::Active,
                local_data: None,
                local_index: None,
                load_metrics: LoadMetrics {
                    cpu_utilization: 0.0,
                    memory_utilization: 0.0,
                    network_utilization: 0.0,
                    partition_count: 0,
                    operation_count: 0,
                    last_update: Instant::now(),
                },
                last_heartbeat: Instant::now(),
                assigned_partitions: Vec::new(),
            };

            nodes.push(Arc::new(TokioRwLock::new(node)));
        }

        // Adaptive strategy with an 80% load threshold; stored only — no
        // method in this impl consults it yet.
        let load_balancer = LoadBalancer {
            node_loads: HashMap::new(),
            strategy: LoadBalancingStrategy::AdaptiveLoad,
            last_rebalance: Instant::now(),
            load_threshold: 0.8,
        };

        // 10s silence threshold; recovery strategies start empty (see
        // `default_recovery_strategies`, which is not wired in here).
        let fault_detector = FaultDetector {
            node_health: HashMap::new(),
            failure_threshold: Duration::from_secs(10),
            recovery_strategies: HashMap::new(),
        };

        let communication = CommunicationLayer {
            channels,
            compression_enabled: config.compression,
            stats: CommunicationStats {
                messages_sent: 0,
                messages_received: 0,
                bytes_sent: 0,
                bytes_received: 0,
                average_latency_ms: 0.0,
            },
        };

        let cluster_state = ClusterState {
            active_nodes: (0..config.node_count).collect(),
            total_data_points: 0,
            total_partitions: 0,
            health_score: 1.0,
            performance_metrics: ClusterPerformanceMetrics {
                avg_query_latency_ms: 0.0,
                throughput_qps: 0.0,
                load_balance_score: 1.0,
                // Nominal 0.8 when fault tolerance is enabled, else 0.
                fault_tolerance_level: if config.fault_tolerance { 0.8 } else { 0.0 },
            },
        };

        Ok(Self {
            config,
            nodes,
            master_node_id: 0,
            partitions: Arc::new(TokioRwLock::new(HashMap::new())),
            load_balancer: Arc::new(TokioRwLock::new(load_balancer)),
            fault_detector: Arc::new(TokioRwLock::new(fault_detector)),
            communication: Arc::new(TokioRwLock::new(communication)),
            cluster_state: Arc::new(TokioRwLock::new(cluster_state)),
        })
    }
666
667 #[allow(dead_code)]
669 fn default_recovery_strategies(&self) -> HashMap<FailureType, RecoveryStrategy> {
670 let mut strategies = HashMap::new();
671 strategies.insert(FailureType::NodeUnresponsive, RecoveryStrategy::Restart);
672 strategies.insert(FailureType::HighLatency, RecoveryStrategy::WaitAndRetry);
673 strategies.insert(FailureType::ResourceExhaustion, RecoveryStrategy::Migrate);
674 strategies.insert(FailureType::PartialFailure, RecoveryStrategy::Replicate);
675 strategies.insert(FailureType::NetworkPartition, RecoveryStrategy::Isolate);
676 strategies
677 }
678
679 pub async fn distribute_data(&mut self, data: &ArrayView2<'_, f64>) -> SpatialResult<()> {
681 let (n_points, n_dims) = data.dim();
682
683 let partitions = self.create_spatial_partitions(data).await?;
685
686 self.assign_partitions_to_nodes(&partitions).await?;
688
689 self.build_distributed_indices().await?;
691
692 {
694 let mut state = self.cluster_state.write().await;
695 state.total_data_points = n_points;
696 state.total_partitions = partitions.len();
697 }
698
699 Ok(())
700 }
701
    /// Splits `data` into spatially coherent partitions.
    ///
    /// Points are ordered by a Z-order (Morton) code computed over the
    /// global bounds, then chopped into contiguous runs of roughly equal
    /// size. Targets two partitions per node.
    async fn create_spatial_partitions(
        &self,
        data: &ArrayView2<'_, f64>,
    ) -> SpatialResult<Vec<DataPartition>> {
        let (n_points, n_dims) = data.dim();
        let target_partitions = self.config.node_count * 2;

        // Global axis-aligned bounds over all points.
        let mut min_coords = Array1::from_elem(n_dims, f64::INFINITY);
        let mut max_coords = Array1::from_elem(n_dims, f64::NEG_INFINITY);

        for point in data.outer_iter() {
            for (i, &coord) in point.iter().enumerate() {
                min_coords[i] = min_coords[i].min(coord);
                max_coords[i] = max_coords[i].max(coord);
            }
        }

        let global_bounds = SpatialBounds {
            min_coords,
            max_coords,
        };

        // Tag each point with its Z-order code (16-bit resolution).
        let mut point_z_orders = Vec::new();
        for (i, point) in data.outer_iter().enumerate() {
            let z_order = self.calculate_z_order(&point.to_owned(), &global_bounds, 16);
            point_z_orders.push((i, z_order, point.to_owned()));
        }

        // Sorting by Z-order keeps spatially nearby points adjacent, so
        // contiguous runs form compact partitions.
        point_z_orders.sort_by_key(|(_, z_order_, _)| *z_order_);

        let points_per_partition = n_points.div_ceil(target_partitions);
        let mut partitions = Vec::new();

        for partition_id in 0..target_partitions {
            let start_idx = partition_id * points_per_partition;
            let end_idx = ((partition_id + 1) * points_per_partition).min(n_points);

            // Fewer points than partition slots: stop once data runs out.
            if start_idx >= n_points {
                break;
            }

            // Copy this run of points and track its local bounds.
            let partition_size = end_idx - start_idx;
            let mut partition_data = Array2::zeros((partition_size, n_dims));
            let mut partition_min = Array1::from_elem(n_dims, f64::INFINITY);
            let mut partition_max = Array1::from_elem(n_dims, f64::NEG_INFINITY);

            for (i, (_, _, point)) in point_z_orders[start_idx..end_idx].iter().enumerate() {
                partition_data.row_mut(i).assign(point);

                for (j, &coord) in point.iter().enumerate() {
                    partition_min[j] = partition_min[j].min(coord);
                    partition_max[j] = partition_max[j].max(coord);
                }
            }

            let partition_bounds = SpatialBounds {
                min_coords: partition_min,
                max_coords: partition_max,
            };

            // Round-robin primary placement; one replica on the next node
            // when fault tolerance is enabled.
            let partition = DataPartition {
                partition_id,
                bounds: partition_bounds,
                data: partition_data,
                primary_node: partition_id % self.config.node_count,
                replica_nodes: if self.config.fault_tolerance {
                    vec![(partition_id + 1) % self.config.node_count]
                } else {
                    Vec::new()
                },
                size: partition_size,
                last_modified: Instant::now(),
            };

            partitions.push(partition);
        }

        Ok(partitions)
    }
787
788 fn calculate_z_order(
790 &self,
791 point: &Array1<f64>,
792 bounds: &SpatialBounds,
793 resolution: usize,
794 ) -> u64 {
795 let mut z_order = 0u64;
796
797 for bit in 0..resolution {
798 for (dim, ((&coord, &min_coord), &max_coord)) in point
799 .iter()
800 .zip(bounds.min_coords.iter())
801 .zip(bounds.max_coords.iter())
802 .enumerate()
803 {
804 if dim >= 3 {
805 break;
806 } let normalized = if max_coord > min_coord {
809 (coord - min_coord) / (max_coord - min_coord)
810 } else {
811 0.5
812 };
813
814 let bit_val = if normalized >= 0.5 { 1u64 } else { 0u64 };
815 let bit_pos = bit * 3 + dim; if bit_pos < 64 {
818 z_order |= bit_val << bit_pos;
819 }
820 }
821 }
822
823 z_order
824 }
825
826 async fn assign_partitions_to_nodes(
828 &mut self,
829 partitions: &[DataPartition],
830 ) -> SpatialResult<()> {
831 let mut partition_map = HashMap::new();
832
833 for partition in partitions {
834 partition_map.insert(partition.partition_id, partition.clone());
835
836 let primary_node = &self.nodes[partition.primary_node];
838 {
839 let mut node = primary_node.write().await;
840 node.assigned_partitions.push(partition.partition_id);
841
842 if let Some(ref existing_data) = node.local_data {
844 let (existing_rows, cols) = existing_data.dim();
846 let (new_rows_, _) = partition.data.dim();
847 let total_rows = existing_rows + new_rows_;
848
849 let mut combined_data = Array2::zeros((total_rows, cols));
850 combined_data
851 .slice_mut(s![..existing_rows, ..])
852 .assign(existing_data);
853 combined_data
854 .slice_mut(s![existing_rows.., ..])
855 .assign(&partition.data);
856 node.local_data = Some(combined_data);
857 } else {
858 node.local_data = Some(partition.data.clone());
859 }
860
861 node.load_metrics.partition_count += 1;
862 }
863
864 for &replica_node_id in &partition.replica_nodes {
866 let replica_node = &self.nodes[replica_node_id];
867 let mut node = replica_node.write().await;
868 node.assigned_partitions.push(partition.partition_id);
869
870 if let Some(ref existing_data) = node.local_data {
872 let (existing_rows, cols) = existing_data.dim();
874 let (new_rows_, _) = partition.data.dim();
875 let total_rows = existing_rows + new_rows_;
876
877 let mut combined_data = Array2::zeros((total_rows, cols));
878 combined_data
879 .slice_mut(s![..existing_rows, ..])
880 .assign(existing_data);
881 combined_data
882 .slice_mut(s![existing_rows.., ..])
883 .assign(&partition.data);
884 node.local_data = Some(combined_data);
885 } else {
886 node.local_data = Some(partition.data.clone());
887 }
888
889 node.load_metrics.partition_count += 1;
890 }
891 }
892
893 {
894 let mut partitions_lock = self.partitions.write().await;
895 *partitions_lock = partition_map;
896 }
897
898 Ok(())
899 }
900
901 async fn build_distributed_indices(&mut self) -> SpatialResult<()> {
903 for node_arc in &self.nodes {
905 let mut node = node_arc.write().await;
906
907 if let Some(ref local_data) = node.local_data {
908 let (n_points, n_dims) = local_data.dim();
910 let mut min_coords = Array1::from_elem(n_dims, f64::INFINITY);
911 let mut max_coords = Array1::from_elem(n_dims, f64::NEG_INFINITY);
912
913 for point in local_data.outer_iter() {
914 for (i, &coord) in point.iter().enumerate() {
915 min_coords[i] = min_coords[i].min(coord);
916 max_coords[i] = max_coords[i].max(coord);
917 }
918 }
919
920 let local_bounds = SpatialBounds {
921 min_coords,
922 max_coords,
923 };
924
925 let kdtree = crate::KDTree::new(local_data)?;
927
928 let local_index = LocalSpatialIndex {
929 kdtree: Some(kdtree),
930 bounds: local_bounds.clone(),
931 stats: IndexStatistics {
932 build_time_ms: 0.0, memory_usage_bytes: n_points * n_dims * 8, query_count: 0,
935 avg_query_time_ms: 0.0,
936 },
937 };
938
939 let routing_table = RoutingTable {
941 entries: BTreeMap::new(),
942 cache: HashMap::new(),
943 };
944
945 let global_metadata = GlobalIndexMetadata {
947 global_bounds: local_bounds.clone(), partition_map: HashMap::new(),
949 version: 1,
950 };
951
952 let distributed_index = DistributedSpatialIndex {
953 local_index,
954 global_metadata,
955 routing_table,
956 };
957
958 node.local_index = Some(distributed_index);
959 }
960 }
961
962 Ok(())
963 }
964
965 pub async fn distributed_kmeans(
967 &mut self,
968 k: usize,
969 max_iterations: usize,
970 ) -> SpatialResult<(Array2<f64>, Array1<usize>)> {
971 let initial_centroids = self.initialize_distributed_centroids(k).await?;
973 let mut centroids = initial_centroids;
974
975 for _iteration in 0..max_iterations {
976 let local_assignments = self.distributed_assignment_step(¢roids).await?;
978
979 let new_centroids = self
981 .distributed_centroid_update(&local_assignments, k)
982 .await?;
983
984 let centroid_change = self.calculate_centroid_change(¢roids, &new_centroids);
986 if centroid_change < 1e-6 {
987 break;
988 }
989
990 centroids = new_centroids;
991 }
992
993 let final_assignments = self.collect_final_assignments(¢roids).await?;
995
996 Ok((centroids, final_assignments))
997 }
998
    /// k-means++-style seeding: the first centroid is a random data point;
    /// each subsequent one is sampled with probability proportional to its
    /// distance from the centroids chosen so far.
    async fn initialize_distributed_centroids(&self, k: usize) -> SpatialResult<Array2<f64>> {
        let first_centroid = self.get_random_point_from_cluster().await?;

        let n_dims = first_centroid.len();
        let mut centroids = Array2::zeros((k, n_dims));
        centroids.row_mut(0).assign(&first_centroid);

        for i in 1..k {
            // Distances of all points to the i centroids already chosen.
            let distances = self
                .compute_distributed_distances(&centroids.slice(s![..i, ..]))
                .await?;
            let next_centroid = self.select_next_centroid_weighted(&distances).await?;
            centroids.row_mut(i).assign(&next_centroid);
        }

        Ok(centroids)
    }
1019
1020 async fn get_random_point_from_cluster(&self) -> SpatialResult<Array1<f64>> {
1022 for node_arc in &self.nodes {
1023 let node = node_arc.read().await;
1024 if let Some(ref local_data) = node.local_data {
1025 if local_data.nrows() > 0 {
1026 let idx = (random_f64() * local_data.nrows() as f64) as usize;
1027 return Ok(local_data.row(idx).to_owned());
1028 }
1029 }
1030 }
1031
1032 Err(SpatialError::InvalidInput(
1033 "No data found in cluster".to_string(),
1034 ))
1035 }
1036
1037 async fn compute_distributed_distances(
1039 &self,
1040 centroids: &ArrayView2<'_, f64>,
1041 ) -> SpatialResult<Vec<f64>> {
1042 let mut all_distances = Vec::new();
1043
1044 for node_arc in &self.nodes {
1045 let node = node_arc.read().await;
1046 if let Some(ref local_data) = node.local_data {
1047 for point in local_data.outer_iter() {
1048 let mut min_distance = f64::INFINITY;
1049
1050 for centroid in centroids.outer_iter() {
1051 let distance: f64 = point
1052 .iter()
1053 .zip(centroid.iter())
1054 .map(|(&a, &b)| (a - b).powi(2))
1055 .sum::<f64>()
1056 .sqrt();
1057
1058 min_distance = min_distance.min(distance);
1059 }
1060
1061 all_distances.push(min_distance);
1062 }
1063 }
1064 }
1065
1066 Ok(all_distances)
1067 }
1068
1069 async fn select_next_centroid_weighted(
1071 &self,
1072 _distances: &[f64],
1073 ) -> SpatialResult<Array1<f64>> {
1074 let total_distance: f64 = _distances.iter().sum();
1075 let target = random_f64() * total_distance;
1076
1077 let mut cumulative = 0.0;
1078 let mut point_index = 0;
1079
1080 for &distance in _distances {
1081 cumulative += distance;
1082 if cumulative >= target {
1083 break;
1084 }
1085 point_index += 1;
1086 }
1087
1088 let mut current_index = 0;
1090 for node_arc in &self.nodes {
1091 let node = node_arc.read().await;
1092 if let Some(ref local_data) = node.local_data {
1093 if current_index + local_data.nrows() > point_index {
1094 let local_index = point_index - current_index;
1095 return Ok(local_data.row(local_index).to_owned());
1096 }
1097 current_index += local_data.nrows();
1098 }
1099 }
1100
1101 Err(SpatialError::InvalidInput(
1102 "Point index out of range".to_string(),
1103 ))
1104 }
1105
1106 async fn distributed_assignment_step(
1108 &self,
1109 centroids: &Array2<f64>,
1110 ) -> SpatialResult<Vec<(usize, Array1<usize>)>> {
1111 let mut local_assignments = Vec::new();
1112
1113 for (node_id, node_arc) in self.nodes.iter().enumerate() {
1114 let node = node_arc.read().await;
1115 if let Some(ref local_data) = node.local_data {
1116 let (n_points_, _) = local_data.dim();
1117 let mut assignments = Array1::zeros(n_points_);
1118
1119 for (i, point) in local_data.outer_iter().enumerate() {
1120 let mut best_cluster = 0;
1121 let mut best_distance = f64::INFINITY;
1122
1123 for (j, centroid) in centroids.outer_iter().enumerate() {
1124 let distance: f64 = point
1125 .iter()
1126 .zip(centroid.iter())
1127 .map(|(&a, &b)| (a - b).powi(2))
1128 .sum::<f64>()
1129 .sqrt();
1130
1131 if distance < best_distance {
1132 best_distance = distance;
1133 best_cluster = j;
1134 }
1135 }
1136
1137 assignments[i] = best_cluster;
1138 }
1139
1140 local_assignments.push((node_id, assignments));
1141 }
1142 }
1143
1144 Ok(local_assignments)
1145 }
1146
    /// Computes new centroids as the mean of the points assigned to each
    /// cluster, accumulated across all nodes.
    ///
    /// NOTE(review): clusters that received no points keep an all-zero
    /// centroid row, which can pull later iterations toward the origin —
    /// confirm this is intended.
    async fn distributed_centroid_update(
        &self,
        local_assignments: &[(usize, Array1<usize>)],
        k: usize,
    ) -> SpatialResult<Array2<f64>> {
        // Per-cluster coordinate sums and point counts.
        let mut cluster_sums: HashMap<usize, Array1<f64>> = HashMap::new();
        let mut cluster_counts: HashMap<usize, usize> = HashMap::new();

        for (node_id, assignments) in local_assignments {
            let node = self.nodes[*node_id].read().await;
            if let Some(ref local_data) = node.local_data {
                let (_, n_dims) = local_data.dim();

                for (i, &cluster) in assignments.iter().enumerate() {
                    let point = local_data.row(i);

                    let cluster_sum = cluster_sums
                        .entry(cluster)
                        .or_insert_with(|| Array1::zeros(n_dims));
                    let cluster_count = cluster_counts.entry(cluster).or_insert(0);

                    for (j, &coord) in point.iter().enumerate() {
                        cluster_sum[j] += coord;
                    }
                    *cluster_count += 1;
                }
            }
        }

        // Dimensionality from any accumulated sum; the fallback of 2 is an
        // arbitrary default hit only when no assignments were produced.
        let n_dims = cluster_sums
            .values()
            .next()
            .map(|sum| sum.len())
            .unwrap_or(2);

        let mut new_centroids = Array2::zeros((k, n_dims));

        for cluster in 0..k {
            if let (Some(sum), Some(&count)) =
                (cluster_sums.get(&cluster), cluster_counts.get(&cluster))
            {
                if count > 0 {
                    for j in 0..n_dims {
                        new_centroids[[cluster, j]] = sum[j] / count as f64;
                    }
                }
            }
        }

        Ok(new_centroids)
    }
1201
1202 fn calculate_centroid_change(
1204 &self,
1205 old_centroids: &Array2<f64>,
1206 new_centroids: &Array2<f64>,
1207 ) -> f64 {
1208 let mut total_change = 0.0;
1209
1210 for (old_row, new_row) in old_centroids.outer_iter().zip(new_centroids.outer_iter()) {
1211 let change: f64 = old_row
1212 .iter()
1213 .zip(new_row.iter())
1214 .map(|(&a, &b)| (a - b).powi(2))
1215 .sum::<f64>()
1216 .sqrt();
1217 total_change += change;
1218 }
1219
1220 total_change / old_centroids.nrows() as f64
1221 }
1222
1223 async fn collect_final_assignments(
1225 &self,
1226 centroids: &Array2<f64>,
1227 ) -> SpatialResult<Array1<usize>> {
1228 let mut all_assignments = Vec::new();
1229
1230 for node_arc in &self.nodes {
1231 let node = node_arc.read().await;
1232 if let Some(ref local_data) = node.local_data {
1233 for point in local_data.outer_iter() {
1234 let mut best_cluster = 0;
1235 let mut best_distance = f64::INFINITY;
1236
1237 for (j, centroid) in centroids.outer_iter().enumerate() {
1238 let distance: f64 = point
1239 .iter()
1240 .zip(centroid.iter())
1241 .map(|(&a, &b)| (a - b).powi(2))
1242 .sum::<f64>()
1243 .sqrt();
1244
1245 if distance < best_distance {
1246 best_distance = distance;
1247 best_cluster = j;
1248 }
1249 }
1250
1251 all_assignments.push(best_cluster);
1252 }
1253 }
1254 }
1255
1256 Ok(Array1::from(all_assignments))
1257 }
1258
1259 pub async fn distributed_knn_search(
1261 &self,
1262 query_point: &ArrayView1<'_, f64>,
1263 k: usize,
1264 ) -> SpatialResult<Vec<(usize, f64)>> {
1265 let mut all_neighbors = Vec::new();
1266
1267 for node_arc in &self.nodes {
1269 let node = node_arc.read().await;
1270 if let Some(ref local_index) = node.local_index {
1271 if let Some(ref kdtree) = local_index.local_index.kdtree {
1272 if local_index.local_index.bounds.contains(query_point) {
1274 let (indices, distances) =
1275 kdtree.query(query_point.as_slice().expect("Operation failed"), k)?;
1276
1277 for (idx, dist) in indices.iter().zip(distances.iter()) {
1278 all_neighbors.push((*idx, *dist));
1279 }
1280 }
1281 }
1282 }
1283 }
1284
1285 all_neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("Operation failed"));
1287 all_neighbors.truncate(k);
1288
1289 Ok(all_neighbors)
1290 }
1291
1292 pub async fn get_cluster_statistics(&self) -> SpatialResult<ClusterStatistics> {
1294 let state = self.cluster_state.read().await;
1295 let _load_balancer = self.load_balancer.read().await;
1296 let communication = self.communication.read().await;
1297
1298 let active_node_count = state.active_nodes.len();
1299 let total_partitions = state.total_partitions;
1300 let avg_partitions_per_node = if active_node_count > 0 {
1301 total_partitions as f64 / active_node_count as f64
1302 } else {
1303 0.0
1304 };
1305
1306 Ok(ClusterStatistics {
1307 active_nodes: active_node_count,
1308 total_data_points: state.total_data_points,
1309 total_partitions,
1310 avg_partitions_per_node,
1311 health_score: state.health_score,
1312 load_balance_score: state.performance_metrics.load_balance_score,
1313 avg_query_latency_ms: state.performance_metrics.avg_query_latency_ms,
1314 throughput_qps: state.performance_metrics.throughput_qps,
1315 total_messages_sent: communication.stats.messages_sent,
1316 total_bytes_sent: communication.stats.bytes_sent,
1317 avg_communication_latency_ms: communication.stats.average_latency_ms,
1318 })
1319 }
1320}
1321
/// Point-in-time summary of cluster health and activity, as returned by
/// `get_cluster_statistics`.
#[derive(Debug, Clone)]
pub struct ClusterStatistics {
    /// Number of currently active nodes.
    pub active_nodes: usize,
    /// Total points distributed across the cluster.
    pub total_data_points: usize,
    /// Total partitions across the cluster.
    pub total_partitions: usize,
    /// Mean partitions per active node (0 when no node is active).
    pub avg_partitions_per_node: f64,
    /// Overall cluster health score.
    pub health_score: f64,
    /// How evenly load is spread across nodes.
    pub load_balance_score: f64,
    /// Average query latency, in milliseconds.
    pub avg_query_latency_ms: f64,
    /// Throughput in queries per second.
    pub throughput_qps: f64,
    /// Total messages sent between nodes.
    pub total_messages_sent: u64,
    /// Total payload bytes sent between nodes.
    pub total_bytes_sent: u64,
    /// Average inter-node message latency, in milliseconds.
    pub avg_communication_latency_ms: f64,
}
1337
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Builder methods set their flags, and enabling fault tolerance bumps
    // the replication factor to 2.
    #[test]
    fn test_nodeconfig() {
        let config = NodeConfig::new()
            .with_node_count(4)
            .with_fault_tolerance(true)
            .with_load_balancing(true);

        assert_eq!(config.node_count, 4);
        assert!(config.fault_tolerance);
        assert!(config.load_balancing);
        assert_eq!(config.replication_factor, 2);
    }

    // Containment is inclusive of the box edges; volume is the product of
    // side lengths.
    #[test]
    fn test_spatial_bounds() {
        let bounds = SpatialBounds {
            min_coords: array![0.0, 0.0],
            max_coords: array![1.0, 1.0],
        };

        assert!(bounds.contains(&array![0.5, 0.5].view()));
        assert!(!bounds.contains(&array![1.5, 0.5].view()));
        assert_eq!(bounds.volume(), 1.0);
    }

    // The weighted load score stays within (0, 1) for in-range inputs.
    #[test]
    fn test_load_metrics() {
        let metrics = LoadMetrics {
            cpu_utilization: 0.5,
            memory_utilization: 0.3,
            network_utilization: 0.2,
            partition_count: 2,
            operation_count: 100,
            last_update: Instant::now(),
        };

        let load_score = metrics.load_score();
        assert!(load_score > 0.0 && load_score < 1.0);
    }

    // Construction yields the configured number of nodes and node 0 as
    // master.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_distributed_cluster_creation() {
        let config = NodeConfig::new()
            .with_node_count(2)
            .with_fault_tolerance(false);

        let cluster = DistributedSpatialCluster::new(config);
        assert!(cluster.is_ok());

        let cluster = cluster.expect("Operation failed");
        assert_eq!(cluster.nodes.len(), 2);
        assert_eq!(cluster.master_node_id, 0);
    }

    // Distributing data updates the cluster-wide point and partition
    // counters.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_data_distribution() {
        let config = NodeConfig::new()
            .with_node_count(2)
            .with_fault_tolerance(false);

        let mut cluster = DistributedSpatialCluster::new(config).expect("Operation failed");
        let data = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];

        let result = cluster.distribute_data(&data.view()).await;
        assert!(result.is_ok());

        let stats = cluster
            .get_cluster_statistics()
            .await
            .expect("Operation failed");
        assert_eq!(stats.total_data_points, 4);
        assert!(stats.total_partitions > 0);
    }

    // End-to-end k-means over two well-separated groups returns the
    // requested number of centroids and one label per point.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_distributed_kmeans() {
        let config = NodeConfig::new().with_node_count(2);
        let mut cluster = DistributedSpatialCluster::new(config).expect("Operation failed");

        let data = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [10.0, 10.0],
            [11.0, 10.0]
        ];
        cluster
            .distribute_data(&data.view())
            .await
            .expect("Operation failed");

        let result = cluster.distributed_kmeans(2, 10).await;
        assert!(result.is_ok());

        let (centroids, assignments) = result.expect("Operation failed");
        assert_eq!(centroids.nrows(), 2);
        assert_eq!(assignments.len(), 6);
    }
}