1use crate::error::{SpatialError, SpatialResult};
60use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
61use scirs2_core::random::quick::random_f64;
62use std::collections::{BTreeMap, HashMap, VecDeque};
63use std::sync::Arc;
64use std::time::{Duration, Instant};
65use tokio::sync::{mpsc, RwLock as TokioRwLock};
66
/// Settings used to build a [`DistributedSpatialCluster`].
///
/// Values are normally set through the chained `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct NodeConfig {
    /// Number of nodes the cluster creates.
    pub node_count: usize,
    /// When set, each partition is also assigned to a replica node.
    pub fault_tolerance: bool,
    /// Load-balancing switch (stored only; not read by the code in this file view).
    pub load_balancing: bool,
    /// Copied into the communication layer's `compression_enabled` flag.
    pub compression: bool,
    /// Communication timeout in milliseconds (stored only in this file view).
    pub communication_timeout_ms: u64,
    /// Heartbeat interval in milliseconds (stored only in this file view).
    pub heartbeat_interval_ms: u64,
    /// Maximum retry count (stored only in this file view).
    pub max_retries: usize,
    /// Desired number of copies per partition; raised to at least 2 when
    /// fault tolerance is switched on.
    pub replication_factor: usize,
}

impl Default for NodeConfig {
    /// Single node, every optional feature off, 5 s timeout, 1 s heartbeat,
    /// 3 retries and no replication.
    fn default() -> Self {
        NodeConfig {
            node_count: 1,
            fault_tolerance: false,
            load_balancing: false,
            compression: false,
            communication_timeout_ms: 5000,
            heartbeat_interval_ms: 1000,
            max_retries: 3,
            replication_factor: 1,
        }
    }
}

impl NodeConfig {
    /// Creates the default configuration (see [`Default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with `node_count` replaced by `count`.
    pub fn with_node_count(self, count: usize) -> Self {
        Self {
            node_count: count,
            ..self
        }
    }

    /// Sets the fault-tolerance flag.
    ///
    /// Enabling it lifts `replication_factor` to 2 when it is lower;
    /// disabling it leaves `replication_factor` untouched.
    pub fn with_fault_tolerance(self, enabled: bool) -> Self {
        let replication_factor = if enabled {
            self.replication_factor.max(2)
        } else {
            self.replication_factor
        };
        Self {
            fault_tolerance: enabled,
            replication_factor,
            ..self
        }
    }

    /// Sets the load-balancing flag.
    pub fn with_load_balancing(self, enabled: bool) -> Self {
        Self {
            load_balancing: enabled,
            ..self
        }
    }

    /// Sets the compression flag.
    pub fn with_compression(self, enabled: bool) -> Self {
        Self {
            compression: enabled,
            ..self
        }
    }
}
136
/// In-process model of a cluster of spatial nodes.
///
/// Every shared component is wrapped in `Arc<TokioRwLock<..>>`, so access
/// goes through async read/write locks.
#[derive(Debug)]
pub struct DistributedSpatialCluster {
    /// Settings the cluster was created with.
    config: NodeConfig,
    /// One entry per node; a node's id is its position in this vector.
    nodes: Vec<Arc<TokioRwLock<NodeInstance>>>,
    /// Id of the master node; `new` always sets it to 0.
    #[allow(dead_code)]
    master_node_id: usize,
    /// Partition id -> partition; replaced wholesale by
    /// `assign_partitions_to_nodes`.
    partitions: Arc<TokioRwLock<HashMap<usize, DataPartition>>>,
    /// Load-balancer state; only read-locked (never inspected) by
    /// `get_cluster_statistics` in the code visible here.
    load_balancer: Arc<TokioRwLock<LoadBalancer>>,
    /// Fault-detector state; created empty by `new` and not used afterwards
    /// in the code visible here.
    #[allow(dead_code)]
    fault_detector: Arc<TokioRwLock<FaultDetector>>,
    /// Per-node message senders plus traffic counters.
    communication: Arc<TokioRwLock<CommunicationLayer>>,
    /// Aggregate counters reported by `get_cluster_statistics`.
    cluster_state: Arc<TokioRwLock<ClusterState>>,
}
159
/// State held for a single node of the cluster.
#[derive(Debug)]
pub struct NodeInstance {
    /// Node identifier; `DistributedSpatialCluster::new` assigns `0..node_count`.
    pub node_id: usize,
    /// Lifecycle status; every node starts as `NodeStatus::Active`.
    pub status: NodeStatus,
    /// Rows of all partitions assigned to this node (as primary or replica),
    /// stacked into one matrix. `None` until data has been distributed.
    pub local_data: Option<Array2<f64>>,
    /// Index built over `local_data` by `build_distributed_indices`.
    pub local_index: Option<DistributedSpatialIndex>,
    /// Load figures for this node; only `partition_count` is updated by the
    /// code visible here.
    pub load_metrics: LoadMetrics,
    /// Set to the creation time by `new`; not refreshed in the code visible here.
    pub last_heartbeat: Instant,
    /// Ids of the partitions stored on this node.
    pub assigned_partitions: Vec<usize>,
}
178
/// Lifecycle state of a node.
///
/// NOTE(review): only `Active` is ever assigned in the code visible here; the
/// per-variant descriptions are inferred from the names — confirm.
#[derive(Debug, Clone, PartialEq)]
pub enum NodeStatus {
    /// Node is up; the initial state of every node.
    Active,
    /// Presumably: node is known but not serving.
    Inactive,
    /// Presumably: node has been declared failed.
    Failed,
    /// Presumably: node is being restored after a failure.
    Recovering,
    /// Presumably: node is in the process of joining the cluster.
    Joining,
    /// Presumably: node is in the process of leaving the cluster.
    Leaving,
}
189
/// A contiguous slice of the Z-order-sorted input assigned to one primary
/// node (and optionally replicas).
#[derive(Debug, Clone)]
pub struct DataPartition {
    /// Sequential id, starting at 0, in Z-order of the contained points.
    pub partition_id: usize,
    /// Axis-aligned bounding box of the rows in `data`.
    pub bounds: SpatialBounds,
    /// The partition's points, one per row.
    pub data: Array2<f64>,
    /// Id of the node that owns the partition
    /// (`partition_id % node_count` when created by the cluster).
    pub primary_node: usize,
    /// Ids of nodes holding an additional copy; empty without fault tolerance.
    pub replica_nodes: Vec<usize>,
    /// Number of rows in `data`.
    pub size: usize,
    /// Creation time of the partition; not updated in the code visible here.
    pub last_modified: Instant,
}
208
209#[derive(Debug, Clone)]
211pub struct SpatialBounds {
212 pub min_coords: Array1<f64>,
214 pub max_coords: Array1<f64>,
216}
217
218impl SpatialBounds {
219 pub fn contains(&self, point: &ArrayView1<f64>) -> bool {
221 point
222 .iter()
223 .zip(self.min_coords.iter())
224 .zip(self.max_coords.iter())
225 .all(|((&coord, &min_coord), &max_coord)| coord >= min_coord && coord <= max_coord)
226 }
227
228 pub fn volume(&self) -> f64 {
230 self.min_coords
231 .iter()
232 .zip(self.max_coords.iter())
233 .map(|(&min_coord, &max_coord)| max_coord - min_coord)
234 .product()
235 }
236}
237
/// Bookkeeping for load balancing.
///
/// None of the fields are read in the code visible here (hence the
/// `dead_code` allowances); they are only initialised by
/// `DistributedSpatialCluster::new`.
#[derive(Debug)]
pub struct LoadBalancer {
    /// Load metrics keyed by node id; starts empty.
    #[allow(dead_code)]
    node_loads: HashMap<usize, LoadMetrics>,
    /// Selected strategy; `new` picks `LoadBalancingStrategy::AdaptiveLoad`.
    #[allow(dead_code)]
    strategy: LoadBalancingStrategy,
    /// Time of the last rebalance; initialised to the creation time.
    #[allow(dead_code)]
    last_rebalance: Instant,
    /// Load threshold; `new` sets 0.8. Presumably compared against
    /// `LoadMetrics::load_score` — confirm.
    #[allow(dead_code)]
    load_threshold: f64,
}
254
/// Available load-balancing policies.
///
/// NOTE(review): no policy is implemented in the code visible here; only
/// `AdaptiveLoad` is ever selected (by `DistributedSpatialCluster::new`).
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    RoundRobin,
    LeastLoaded,
    ProportionalLoad,
    AdaptiveLoad,
}
263
/// Load figures reported for a node.
#[derive(Debug, Clone)]
pub struct LoadMetrics {
    /// CPU utilisation; weighted 0.4 in [`LoadMetrics::load_score`].
    pub cpu_utilization: f64,
    /// Memory utilisation; weighted 0.3 in the load score.
    pub memory_utilization: f64,
    /// Network utilisation; weighted 0.2 in the load score.
    pub network_utilization: f64,
    /// Number of partitions stored on the node.
    pub partition_count: usize,
    /// Operation counter (not part of the load score).
    pub operation_count: usize,
    /// Time the metrics were last refreshed.
    pub last_update: Instant,
}

impl LoadMetrics {
    /// Combines the metrics into a single weighted score.
    ///
    /// Weights are 0.4 (CPU), 0.3 (memory), 0.2 (network) and 0.1 for the
    /// partition count, which saturates at 10 partitions. The utilisation
    /// inputs are used as given and are not clamped.
    pub fn load_score(&self) -> f64 {
        let partition_pressure = (self.partition_count as f64 / 10.0).min(1.0);

        let cpu_term = 0.4 * self.cpu_utilization;
        let memory_term = 0.3 * self.memory_utilization;
        let network_term = 0.2 * self.network_utilization;
        let partition_term = 0.1 * partition_pressure;

        cpu_term + memory_term + network_term + partition_term
    }
}
290
/// Bookkeeping for failure detection.
///
/// None of the fields are read in the code visible here; they are only
/// initialised by `DistributedSpatialCluster::new`.
#[derive(Debug)]
pub struct FaultDetector {
    /// Health record per node id; starts empty.
    #[allow(dead_code)]
    node_health: HashMap<usize, NodeHealth>,
    /// Failure threshold; `new` sets 10 seconds. Presumably the silence
    /// period after which a node counts as failed — confirm.
    #[allow(dead_code)]
    failure_threshold: Duration,
    /// Recovery strategy per failure type; starts empty.
    #[allow(dead_code)]
    recovery_strategies: HashMap<FailureType, RecoveryStrategy>,
}
304
/// Health record kept per node by `FaultDetector`.
///
/// NOTE(review): never constructed in the code visible here; field meanings
/// are inferred from names and types — confirm.
#[derive(Debug, Clone)]
pub struct NodeHealth {
    /// Presumably the time of the last successful contact.
    pub last_contact: Instant,
    /// Presumably the number of failures since the last success.
    pub consecutive_failures: usize,
    /// Queue of observed response times.
    pub response_times: VecDeque<Duration>,
    /// Health score; scale not established by the visible code.
    pub health_score: f64,
}
317
/// Kinds of failure a recovery strategy can be registered for.
///
/// Hashable so it can key the `recovery_strategies` map.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum FailureType {
    /// Mapped to `RecoveryStrategy::Restart` by `default_recovery_strategies`.
    NodeUnresponsive,
    /// Mapped to `RecoveryStrategy::WaitAndRetry`.
    HighLatency,
    /// Mapped to `RecoveryStrategy::Migrate`.
    ResourceExhaustion,
    /// Mapped to `RecoveryStrategy::Replicate`.
    PartialFailure,
    /// Mapped to `RecoveryStrategy::Isolate`.
    NetworkPartition,
}
327
/// Actions that can be taken in response to a `FailureType`.
///
/// NOTE(review): no strategy is executed in the code visible here; the
/// variants are only used as map values and in `RecoveryAction`.
#[derive(Debug, Clone)]
pub enum RecoveryStrategy {
    Restart,
    Migrate,
    Replicate,
    Isolate,
    WaitAndRetry,
}
337
/// Message channels and traffic counters shared by the cluster.
#[derive(Debug)]
pub struct CommunicationLayer {
    /// Sender half of each node's message channel, keyed by node id.
    /// Not used after construction in the code visible here.
    #[allow(dead_code)]
    channels: HashMap<usize, mpsc::Sender<DistributedMessage>>,
    /// Copy of `NodeConfig::compression`; not read in the code visible here.
    #[allow(dead_code)]
    compression_enabled: bool,
    /// Traffic counters surfaced through `get_cluster_statistics`.
    stats: CommunicationStats,
}
350
/// Traffic counters for the communication layer.
///
/// All values start at zero and are never incremented in the code visible here.
#[derive(Debug, Clone)]
pub struct CommunicationStats {
    /// Number of messages sent.
    pub messages_sent: u64,
    /// Number of messages received.
    pub messages_received: u64,
    /// Bytes sent.
    pub bytes_sent: u64,
    /// Bytes received.
    pub bytes_received: u64,
    /// Average message latency in milliseconds.
    pub average_latency_ms: f64,
}
365
/// Message type carried by the per-node channels in `CommunicationLayer`.
///
/// NOTE(review): no variant is constructed or consumed in the code visible
/// here; the per-variant descriptions are inferred from names — confirm.
#[derive(Debug, Clone)]
pub enum DistributedMessage {
    /// Presumably a periodic liveness report carrying the sender's load.
    Heartbeat {
        node_id: usize,
        timestamp: Instant,
        load_metrics: LoadMetrics,
    },
    /// Presumably ships one partition's rows and bounds to a node.
    DataDistribution {
        partition_id: usize,
        data: Array2<f64>,
        bounds: SpatialBounds,
    },
    /// Presumably a query request identified by `query_id`.
    Query {
        query_id: usize,
        query_type: QueryType,
        parameters: QueryParameters,
    },
    /// Presumably the answer of `node_id` to the query `query_id`.
    QueryResponse {
        query_id: usize,
        results: QueryResults,
        node_id: usize,
    },
    /// Presumably instructs nodes to apply a rebalance plan.
    LoadBalance { rebalance_plan: RebalancePlan },
    /// Presumably announces a failure and the plan to recover from it.
    FaultTolerance {
        failure_type: FailureType,
        affected_nodes: Vec<usize>,
        recovery_plan: RecoveryPlan,
    },
}
402
/// Kind of query carried by `DistributedMessage::Query`.
///
/// Each variant has a same-named result shape in `QueryResults`
/// (`KNearestNeighbors` corresponds to `QueryResults::NearestNeighbors`).
#[derive(Debug, Clone)]
pub enum QueryType {
    KNearestNeighbors,
    RangeSearch,
    Clustering,
    DistanceMatrix,
}
411
/// Parameter bag for a `DistributedMessage::Query`.
///
/// NOTE(review): never read in the code visible here; which fields apply to
/// which `QueryType` is inferred from names — confirm.
#[derive(Debug, Clone)]
pub struct QueryParameters {
    /// Query location, when the query has one.
    pub query_point: Option<Array1<f64>>,
    /// Search radius (presumably for range searches).
    pub radius: Option<f64>,
    /// Neighbour count (presumably for k-nearest-neighbour queries).
    pub k: Option<usize>,
    /// Cluster count (presumably for clustering queries).
    pub num_clusters: Option<usize>,
    /// Additional named numeric parameters.
    pub extra_params: HashMap<String, f64>,
}
426
/// Result payload of a `DistributedMessage::QueryResponse`.
///
/// NOTE(review): never constructed in the code visible here; the relation
/// between the fields of a variant is inferred from names — confirm.
#[derive(Debug, Clone)]
pub enum QueryResults {
    /// Neighbour indices with their distances.
    NearestNeighbors {
        indices: Vec<usize>,
        distances: Vec<f64>,
    },
    /// Indices of matching points and the points themselves.
    RangeSearch {
        indices: Vec<usize>,
        points: Array2<f64>,
    },
    /// Cluster centroids and a cluster label per point.
    Clustering {
        centroids: Array2<f64>,
        assignments: Array1<usize>,
    },
    /// Distance matrix.
    DistanceMatrix {
        matrix: Array2<f64>,
    },
}
446
/// Set of partition moves carried by `DistributedMessage::LoadBalance`.
///
/// NOTE(review): never constructed in the code visible here; units and scale
/// of the numeric fields are not established — confirm.
#[derive(Debug, Clone)]
pub struct RebalancePlan {
    /// Partition moves that make up the plan.
    pub migrations: Vec<PartitionMigration>,
    /// Expected improvement in load balance.
    pub load_improvement: f64,
    /// Estimated cost of carrying out the migrations.
    pub migration_cost: f64,
}
457
/// A single partition move inside a `RebalancePlan`.
#[derive(Debug, Clone)]
pub struct PartitionMigration {
    /// Id of the partition to move.
    pub partition_id: usize,
    /// Id of the node currently holding the partition.
    pub from_node: usize,
    /// Id of the node that should receive the partition.
    pub to_node: usize,
    /// Priority of the move; ordering convention not established by the
    /// visible code — confirm.
    pub priority: f64,
}
470
/// Recovery steps carried by `DistributedMessage::FaultTolerance`.
///
/// NOTE(review): never constructed in the code visible here.
#[derive(Debug, Clone)]
pub struct RecoveryPlan {
    /// Steps of the plan.
    pub actions: Vec<RecoveryAction>,
    /// Estimated time needed to complete the plan.
    pub estimated_recovery_time: Duration,
    /// Estimated probability of success; presumably in `[0, 1]` — confirm.
    pub success_probability: f64,
}
481
/// One step of a `RecoveryPlan`.
#[derive(Debug, Clone)]
pub struct RecoveryAction {
    /// Strategy to apply.
    pub action_type: RecoveryStrategy,
    /// Id of the node the action targets.
    pub target_node: usize,
    /// Free-form string parameters for the action.
    pub parameters: HashMap<String, String>,
}
492
/// Aggregate view of the cluster, surfaced by `get_cluster_statistics`.
#[derive(Debug)]
pub struct ClusterState {
    /// Ids of active nodes; `new` fills it with `0..node_count`.
    pub active_nodes: Vec<usize>,
    /// Number of input rows passed to the last `distribute_data` call.
    pub total_data_points: usize,
    /// Number of partitions created by the last `distribute_data` call.
    pub total_partitions: usize,
    /// Overall health score; initialised to 1.0 and not updated in the code
    /// visible here.
    pub health_score: f64,
    /// Performance figures for the cluster.
    pub performance_metrics: ClusterPerformanceMetrics,
}
507
/// Cluster-wide performance figures.
///
/// Only the initial values set by `DistributedSpatialCluster::new` are
/// written in the code visible here.
#[derive(Debug, Clone)]
pub struct ClusterPerformanceMetrics {
    /// Average query latency in milliseconds; starts at 0.0.
    pub avg_query_latency_ms: f64,
    /// Query throughput in queries per second; starts at 0.0.
    pub throughput_qps: f64,
    /// Load-balance score; starts at 1.0.
    pub load_balance_score: f64,
    /// Fault-tolerance level; starts at 0.8 with fault tolerance enabled,
    /// 0.0 otherwise.
    pub fault_tolerance_level: f64,
}
520
/// Per-node index bundle created by `build_distributed_indices`.
#[derive(Debug)]
pub struct DistributedSpatialIndex {
    /// KD-tree and bounds over the node's own data.
    pub local_index: LocalSpatialIndex,
    /// Cluster-wide metadata; currently filled with the node's local bounds
    /// and an empty partition map.
    pub global_metadata: GlobalIndexMetadata,
    /// Routing information; created empty.
    pub routing_table: RoutingTable,
}
531
/// Spatial index over the data stored on one node.
#[derive(Debug)]
pub struct LocalSpatialIndex {
    /// KD-tree built from the node's `local_data`.
    pub kdtree: Option<crate::KDTree<f64, crate::EuclideanDistance<f64>>>,
    /// Bounding box of the node's `local_data`.
    pub bounds: SpatialBounds,
    /// Build and query statistics for the index.
    pub stats: IndexStatistics,
}
542
/// Cluster-wide index metadata stored alongside each local index.
#[derive(Debug, Clone)]
pub struct GlobalIndexMetadata {
    /// Bounds intended to cover the whole data set.
    /// NOTE(review): `build_distributed_indices` stores the node's *local*
    /// bounds here — confirm whether that is intentional.
    pub global_bounds: SpatialBounds,
    /// Bounds per partition id; created empty.
    pub partition_map: HashMap<usize, SpatialBounds>,
    /// Metadata version; set to 1 on build.
    pub version: usize,
}
553
/// Maps spatial keys to lists of ids.
///
/// Both maps are created empty by `build_distributed_indices` and never
/// populated in the code visible here. The `Vec<usize>` values are presumably
/// node ids — confirm.
#[derive(Debug)]
pub struct RoutingTable {
    /// Ordered routing entries keyed by `SpatialKey`.
    pub entries: BTreeMap<SpatialKey, Vec<usize>>,
    /// Unordered lookup cache with the same key and value types as `entries`.
    pub cache: HashMap<SpatialKey, Vec<usize>>,
}
562
/// Key identifying a cell of the space-filling-curve hierarchy.
///
/// Ordering is derived field by field (`z_order` first, then `level`), which
/// is what `BTreeMap` keys rely on. `Hash` is derived as well because the key
/// is also used in a `HashMap`; without it entries could be neither inserted
/// nor looked up.
#[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq, Hash)]
pub struct SpatialKey {
    /// Z-order (Morton) code of the cell.
    pub z_order: u64,
    /// Level of the cell within the hierarchy.
    pub level: usize,
}
571
/// Build and query statistics of a `LocalSpatialIndex`.
#[derive(Debug, Clone)]
pub struct IndexStatistics {
    /// Time taken to build the index, in milliseconds.
    pub build_time_ms: f64,
    /// Estimated memory footprint in bytes; `build_distributed_indices`
    /// stores `rows * cols * 8` (the raw `f64` payload).
    pub memory_usage_bytes: usize,
    /// Number of queries served; not incremented in the code visible here.
    pub query_count: u64,
    /// Average query time in milliseconds; not updated in the code visible here.
    pub avg_query_time_ms: f64,
}
584
585impl DistributedSpatialCluster {
586 pub fn new(config: NodeConfig) -> SpatialResult<Self> {
588 let mut nodes = Vec::new();
589 let mut channels = HashMap::new();
590
591 for node_id in 0..config.node_count {
593 let (sender, receiver) = mpsc::channel(1000);
594 channels.insert(node_id, sender);
595
596 let node = NodeInstance {
597 node_id,
598 status: NodeStatus::Active,
599 local_data: None,
600 local_index: None,
601 load_metrics: LoadMetrics {
602 cpu_utilization: 0.0,
603 memory_utilization: 0.0,
604 network_utilization: 0.0,
605 partition_count: 0,
606 operation_count: 0,
607 last_update: Instant::now(),
608 },
609 last_heartbeat: Instant::now(),
610 assigned_partitions: Vec::new(),
611 };
612
613 nodes.push(Arc::new(TokioRwLock::new(node)));
614 }
615
616 let load_balancer = LoadBalancer {
617 node_loads: HashMap::new(),
618 strategy: LoadBalancingStrategy::AdaptiveLoad,
619 last_rebalance: Instant::now(),
620 load_threshold: 0.8,
621 };
622
623 let fault_detector = FaultDetector {
624 node_health: HashMap::new(),
625 failure_threshold: Duration::from_secs(10),
626 recovery_strategies: HashMap::new(),
627 };
628
629 let communication = CommunicationLayer {
630 channels,
631 compression_enabled: config.compression,
632 stats: CommunicationStats {
633 messages_sent: 0,
634 messages_received: 0,
635 bytes_sent: 0,
636 bytes_received: 0,
637 average_latency_ms: 0.0,
638 },
639 };
640
641 let cluster_state = ClusterState {
642 active_nodes: (0..config.node_count).collect(),
643 total_data_points: 0,
644 total_partitions: 0,
645 health_score: 1.0,
646 performance_metrics: ClusterPerformanceMetrics {
647 avg_query_latency_ms: 0.0,
648 throughput_qps: 0.0,
649 load_balance_score: 1.0,
650 fault_tolerance_level: if config.fault_tolerance { 0.8 } else { 0.0 },
651 },
652 };
653
654 Ok(Self {
655 config,
656 nodes,
657 master_node_id: 0,
658 partitions: Arc::new(TokioRwLock::new(HashMap::new())),
659 load_balancer: Arc::new(TokioRwLock::new(load_balancer)),
660 fault_detector: Arc::new(TokioRwLock::new(fault_detector)),
661 communication: Arc::new(TokioRwLock::new(communication)),
662 cluster_state: Arc::new(TokioRwLock::new(cluster_state)),
663 })
664 }
665
666 #[allow(dead_code)]
668 fn default_recovery_strategies(&self) -> HashMap<FailureType, RecoveryStrategy> {
669 let mut strategies = HashMap::new();
670 strategies.insert(FailureType::NodeUnresponsive, RecoveryStrategy::Restart);
671 strategies.insert(FailureType::HighLatency, RecoveryStrategy::WaitAndRetry);
672 strategies.insert(FailureType::ResourceExhaustion, RecoveryStrategy::Migrate);
673 strategies.insert(FailureType::PartialFailure, RecoveryStrategy::Replicate);
674 strategies.insert(FailureType::NetworkPartition, RecoveryStrategy::Isolate);
675 strategies
676 }
677
678 pub async fn distribute_data(&mut self, data: &ArrayView2<'_, f64>) -> SpatialResult<()> {
680 let (n_points, n_dims) = data.dim();
681
682 let partitions = self.create_spatial_partitions(data).await?;
684
685 self.assign_partitions_to_nodes(&partitions).await?;
687
688 self.build_distributed_indices().await?;
690
691 {
693 let mut state = self.cluster_state.write().await;
694 state.total_data_points = n_points;
695 state.total_partitions = partitions.len();
696 }
697
698 Ok(())
699 }
700
701 async fn create_spatial_partitions(
703 &self,
704 data: &ArrayView2<'_, f64>,
705 ) -> SpatialResult<Vec<DataPartition>> {
706 let (n_points, n_dims) = data.dim();
707 let target_partitions = self.config.node_count * 2; let mut min_coords = Array1::from_elem(n_dims, f64::INFINITY);
711 let mut max_coords = Array1::from_elem(n_dims, f64::NEG_INFINITY);
712
713 for point in data.outer_iter() {
714 for (i, &coord) in point.iter().enumerate() {
715 min_coords[i] = min_coords[i].min(coord);
716 max_coords[i] = max_coords[i].max(coord);
717 }
718 }
719
720 let global_bounds = SpatialBounds {
721 min_coords,
722 max_coords,
723 };
724
725 let mut point_z_orders = Vec::new();
727 for (i, point) in data.outer_iter().enumerate() {
728 let z_order = self.calculate_z_order(&point.to_owned(), &global_bounds, 16);
729 point_z_orders.push((i, z_order, point.to_owned()));
730 }
731
732 point_z_orders.sort_by_key(|(_, z_order_, _)| *z_order_);
734
735 let points_per_partition = n_points.div_ceil(target_partitions);
737 let mut partitions = Vec::new();
738
739 for partition_id in 0..target_partitions {
740 let start_idx = partition_id * points_per_partition;
741 let end_idx = ((partition_id + 1) * points_per_partition).min(n_points);
742
743 if start_idx >= n_points {
744 break;
745 }
746
747 let partition_size = end_idx - start_idx;
749 let mut partition_data = Array2::zeros((partition_size, n_dims));
750 let mut partition_min = Array1::from_elem(n_dims, f64::INFINITY);
751 let mut partition_max = Array1::from_elem(n_dims, f64::NEG_INFINITY);
752
753 for (i, (_, _, point)) in point_z_orders[start_idx..end_idx].iter().enumerate() {
754 partition_data.row_mut(i).assign(point);
755
756 for (j, &coord) in point.iter().enumerate() {
757 partition_min[j] = partition_min[j].min(coord);
758 partition_max[j] = partition_max[j].max(coord);
759 }
760 }
761
762 let partition_bounds = SpatialBounds {
763 min_coords: partition_min,
764 max_coords: partition_max,
765 };
766
767 let partition = DataPartition {
768 partition_id,
769 bounds: partition_bounds,
770 data: partition_data,
771 primary_node: partition_id % self.config.node_count,
772 replica_nodes: if self.config.fault_tolerance {
773 vec![(partition_id + 1) % self.config.node_count]
774 } else {
775 Vec::new()
776 },
777 size: partition_size,
778 last_modified: Instant::now(),
779 };
780
781 partitions.push(partition);
782 }
783
784 Ok(partitions)
785 }
786
787 fn calculate_z_order(
789 &self,
790 point: &Array1<f64>,
791 bounds: &SpatialBounds,
792 resolution: usize,
793 ) -> u64 {
794 let mut z_order = 0u64;
795
796 for bit in 0..resolution {
797 for (dim, ((&coord, &min_coord), &max_coord)) in point
798 .iter()
799 .zip(bounds.min_coords.iter())
800 .zip(bounds.max_coords.iter())
801 .enumerate()
802 {
803 if dim >= 3 {
804 break;
805 } let normalized = if max_coord > min_coord {
808 (coord - min_coord) / (max_coord - min_coord)
809 } else {
810 0.5
811 };
812
813 let bit_val = if normalized >= 0.5 { 1u64 } else { 0u64 };
814 let bit_pos = bit * 3 + dim; if bit_pos < 64 {
817 z_order |= bit_val << bit_pos;
818 }
819 }
820 }
821
822 z_order
823 }
824
825 async fn assign_partitions_to_nodes(
827 &mut self,
828 partitions: &[DataPartition],
829 ) -> SpatialResult<()> {
830 let mut partition_map = HashMap::new();
831
832 for partition in partitions {
833 partition_map.insert(partition.partition_id, partition.clone());
834
835 let primary_node = &self.nodes[partition.primary_node];
837 {
838 let mut node = primary_node.write().await;
839 node.assigned_partitions.push(partition.partition_id);
840
841 if let Some(ref existing_data) = node.local_data {
843 let (existing_rows, cols) = existing_data.dim();
845 let (new_rows_, _) = partition.data.dim();
846 let total_rows = existing_rows + new_rows_;
847
848 let mut combined_data = Array2::zeros((total_rows, cols));
849 combined_data
850 .slice_mut(s![..existing_rows, ..])
851 .assign(existing_data);
852 combined_data
853 .slice_mut(s![existing_rows.., ..])
854 .assign(&partition.data);
855 node.local_data = Some(combined_data);
856 } else {
857 node.local_data = Some(partition.data.clone());
858 }
859
860 node.load_metrics.partition_count += 1;
861 }
862
863 for &replica_node_id in &partition.replica_nodes {
865 let replica_node = &self.nodes[replica_node_id];
866 let mut node = replica_node.write().await;
867 node.assigned_partitions.push(partition.partition_id);
868
869 if let Some(ref existing_data) = node.local_data {
871 let (existing_rows, cols) = existing_data.dim();
873 let (new_rows_, _) = partition.data.dim();
874 let total_rows = existing_rows + new_rows_;
875
876 let mut combined_data = Array2::zeros((total_rows, cols));
877 combined_data
878 .slice_mut(s![..existing_rows, ..])
879 .assign(existing_data);
880 combined_data
881 .slice_mut(s![existing_rows.., ..])
882 .assign(&partition.data);
883 node.local_data = Some(combined_data);
884 } else {
885 node.local_data = Some(partition.data.clone());
886 }
887
888 node.load_metrics.partition_count += 1;
889 }
890 }
891
892 {
893 let mut partitions_lock = self.partitions.write().await;
894 *partitions_lock = partition_map;
895 }
896
897 Ok(())
898 }
899
900 async fn build_distributed_indices(&mut self) -> SpatialResult<()> {
902 for node_arc in &self.nodes {
904 let mut node = node_arc.write().await;
905
906 if let Some(ref local_data) = node.local_data {
907 let (n_points, n_dims) = local_data.dim();
909 let mut min_coords = Array1::from_elem(n_dims, f64::INFINITY);
910 let mut max_coords = Array1::from_elem(n_dims, f64::NEG_INFINITY);
911
912 for point in local_data.outer_iter() {
913 for (i, &coord) in point.iter().enumerate() {
914 min_coords[i] = min_coords[i].min(coord);
915 max_coords[i] = max_coords[i].max(coord);
916 }
917 }
918
919 let local_bounds = SpatialBounds {
920 min_coords,
921 max_coords,
922 };
923
924 let kdtree = crate::KDTree::new(local_data)?;
926
927 let local_index = LocalSpatialIndex {
928 kdtree: Some(kdtree),
929 bounds: local_bounds.clone(),
930 stats: IndexStatistics {
931 build_time_ms: 0.0, memory_usage_bytes: n_points * n_dims * 8, query_count: 0,
934 avg_query_time_ms: 0.0,
935 },
936 };
937
938 let routing_table = RoutingTable {
940 entries: BTreeMap::new(),
941 cache: HashMap::new(),
942 };
943
944 let global_metadata = GlobalIndexMetadata {
946 global_bounds: local_bounds.clone(), partition_map: HashMap::new(),
948 version: 1,
949 };
950
951 let distributed_index = DistributedSpatialIndex {
952 local_index,
953 global_metadata,
954 routing_table,
955 };
956
957 node.local_index = Some(distributed_index);
958 }
959 }
960
961 Ok(())
962 }
963
964 pub async fn distributed_kmeans(
966 &mut self,
967 k: usize,
968 max_iterations: usize,
969 ) -> SpatialResult<(Array2<f64>, Array1<usize>)> {
970 let initial_centroids = self.initialize_distributed_centroids(k).await?;
972 let mut centroids = initial_centroids;
973
974 for _iteration in 0..max_iterations {
975 let local_assignments = self.distributed_assignment_step(¢roids).await?;
977
978 let new_centroids = self
980 .distributed_centroid_update(&local_assignments, k)
981 .await?;
982
983 let centroid_change = self.calculate_centroid_change(¢roids, &new_centroids);
985 if centroid_change < 1e-6 {
986 break;
987 }
988
989 centroids = new_centroids;
990 }
991
992 let final_assignments = self.collect_final_assignments(¢roids).await?;
994
995 Ok((centroids, final_assignments))
996 }
997
998 async fn initialize_distributed_centroids(&self, k: usize) -> SpatialResult<Array2<f64>> {
1000 let first_centroid = self.get_random_point_from_cluster().await?;
1002
1003 let n_dims = first_centroid.len();
1004 let mut centroids = Array2::zeros((k, n_dims));
1005 centroids.row_mut(0).assign(&first_centroid);
1006
1007 for i in 1..k {
1009 let distances = self
1010 .compute_distributed_distances(¢roids.slice(s![..i, ..]))
1011 .await?;
1012 let next_centroid = self.select_next_centroid_weighted(&distances).await?;
1013 centroids.row_mut(i).assign(&next_centroid);
1014 }
1015
1016 Ok(centroids)
1017 }
1018
1019 async fn get_random_point_from_cluster(&self) -> SpatialResult<Array1<f64>> {
1021 for node_arc in &self.nodes {
1022 let node = node_arc.read().await;
1023 if let Some(ref local_data) = node.local_data {
1024 if local_data.nrows() > 0 {
1025 let idx = (random_f64() * local_data.nrows() as f64) as usize;
1026 return Ok(local_data.row(idx).to_owned());
1027 }
1028 }
1029 }
1030
1031 Err(SpatialError::InvalidInput(
1032 "No data found in cluster".to_string(),
1033 ))
1034 }
1035
1036 async fn compute_distributed_distances(
1038 &self,
1039 centroids: &ArrayView2<'_, f64>,
1040 ) -> SpatialResult<Vec<f64>> {
1041 let mut all_distances = Vec::new();
1042
1043 for node_arc in &self.nodes {
1044 let node = node_arc.read().await;
1045 if let Some(ref local_data) = node.local_data {
1046 for point in local_data.outer_iter() {
1047 let mut min_distance = f64::INFINITY;
1048
1049 for centroid in centroids.outer_iter() {
1050 let distance: f64 = point
1051 .iter()
1052 .zip(centroid.iter())
1053 .map(|(&a, &b)| (a - b).powi(2))
1054 .sum::<f64>()
1055 .sqrt();
1056
1057 min_distance = min_distance.min(distance);
1058 }
1059
1060 all_distances.push(min_distance);
1061 }
1062 }
1063 }
1064
1065 Ok(all_distances)
1066 }
1067
1068 async fn select_next_centroid_weighted(
1070 &self,
1071 _distances: &[f64],
1072 ) -> SpatialResult<Array1<f64>> {
1073 let total_distance: f64 = _distances.iter().sum();
1074 let target = random_f64() * total_distance;
1075
1076 let mut cumulative = 0.0;
1077 let mut point_index = 0;
1078
1079 for &distance in _distances {
1080 cumulative += distance;
1081 if cumulative >= target {
1082 break;
1083 }
1084 point_index += 1;
1085 }
1086
1087 let mut current_index = 0;
1089 for node_arc in &self.nodes {
1090 let node = node_arc.read().await;
1091 if let Some(ref local_data) = node.local_data {
1092 if current_index + local_data.nrows() > point_index {
1093 let local_index = point_index - current_index;
1094 return Ok(local_data.row(local_index).to_owned());
1095 }
1096 current_index += local_data.nrows();
1097 }
1098 }
1099
1100 Err(SpatialError::InvalidInput(
1101 "Point index out of range".to_string(),
1102 ))
1103 }
1104
1105 async fn distributed_assignment_step(
1107 &self,
1108 centroids: &Array2<f64>,
1109 ) -> SpatialResult<Vec<(usize, Array1<usize>)>> {
1110 let mut local_assignments = Vec::new();
1111
1112 for (node_id, node_arc) in self.nodes.iter().enumerate() {
1113 let node = node_arc.read().await;
1114 if let Some(ref local_data) = node.local_data {
1115 let (n_points_, _) = local_data.dim();
1116 let mut assignments = Array1::zeros(n_points_);
1117
1118 for (i, point) in local_data.outer_iter().enumerate() {
1119 let mut best_cluster = 0;
1120 let mut best_distance = f64::INFINITY;
1121
1122 for (j, centroid) in centroids.outer_iter().enumerate() {
1123 let distance: f64 = point
1124 .iter()
1125 .zip(centroid.iter())
1126 .map(|(&a, &b)| (a - b).powi(2))
1127 .sum::<f64>()
1128 .sqrt();
1129
1130 if distance < best_distance {
1131 best_distance = distance;
1132 best_cluster = j;
1133 }
1134 }
1135
1136 assignments[i] = best_cluster;
1137 }
1138
1139 local_assignments.push((node_id, assignments));
1140 }
1141 }
1142
1143 Ok(local_assignments)
1144 }
1145
1146 async fn distributed_centroid_update(
1148 &self,
1149 local_assignments: &[(usize, Array1<usize>)],
1150 k: usize,
1151 ) -> SpatialResult<Array2<f64>> {
1152 let mut cluster_sums: HashMap<usize, Array1<f64>> = HashMap::new();
1154 let mut cluster_counts: HashMap<usize, usize> = HashMap::new();
1155
1156 for (node_id, assignments) in local_assignments {
1157 let node = self.nodes[*node_id].read().await;
1158 if let Some(ref local_data) = node.local_data {
1159 let (_, n_dims) = local_data.dim();
1160
1161 for (i, &cluster) in assignments.iter().enumerate() {
1162 let point = local_data.row(i);
1163
1164 let cluster_sum = cluster_sums
1165 .entry(cluster)
1166 .or_insert_with(|| Array1::zeros(n_dims));
1167 let cluster_count = cluster_counts.entry(cluster).or_insert(0);
1168
1169 for (j, &coord) in point.iter().enumerate() {
1170 cluster_sum[j] += coord;
1171 }
1172 *cluster_count += 1;
1173 }
1174 }
1175 }
1176
1177 let n_dims = cluster_sums
1179 .values()
1180 .next()
1181 .map(|sum| sum.len())
1182 .unwrap_or(2);
1183
1184 let mut new_centroids = Array2::zeros((k, n_dims));
1185
1186 for cluster in 0..k {
1187 if let (Some(sum), Some(&count)) =
1188 (cluster_sums.get(&cluster), cluster_counts.get(&cluster))
1189 {
1190 if count > 0 {
1191 for j in 0..n_dims {
1192 new_centroids[[cluster, j]] = sum[j] / count as f64;
1193 }
1194 }
1195 }
1196 }
1197
1198 Ok(new_centroids)
1199 }
1200
1201 fn calculate_centroid_change(
1203 &self,
1204 old_centroids: &Array2<f64>,
1205 new_centroids: &Array2<f64>,
1206 ) -> f64 {
1207 let mut total_change = 0.0;
1208
1209 for (old_row, new_row) in old_centroids.outer_iter().zip(new_centroids.outer_iter()) {
1210 let change: f64 = old_row
1211 .iter()
1212 .zip(new_row.iter())
1213 .map(|(&a, &b)| (a - b).powi(2))
1214 .sum::<f64>()
1215 .sqrt();
1216 total_change += change;
1217 }
1218
1219 total_change / old_centroids.nrows() as f64
1220 }
1221
1222 async fn collect_final_assignments(
1224 &self,
1225 centroids: &Array2<f64>,
1226 ) -> SpatialResult<Array1<usize>> {
1227 let mut all_assignments = Vec::new();
1228
1229 for node_arc in &self.nodes {
1230 let node = node_arc.read().await;
1231 if let Some(ref local_data) = node.local_data {
1232 for point in local_data.outer_iter() {
1233 let mut best_cluster = 0;
1234 let mut best_distance = f64::INFINITY;
1235
1236 for (j, centroid) in centroids.outer_iter().enumerate() {
1237 let distance: f64 = point
1238 .iter()
1239 .zip(centroid.iter())
1240 .map(|(&a, &b)| (a - b).powi(2))
1241 .sum::<f64>()
1242 .sqrt();
1243
1244 if distance < best_distance {
1245 best_distance = distance;
1246 best_cluster = j;
1247 }
1248 }
1249
1250 all_assignments.push(best_cluster);
1251 }
1252 }
1253 }
1254
1255 Ok(Array1::from(all_assignments))
1256 }
1257
1258 pub async fn distributed_knn_search(
1260 &self,
1261 query_point: &ArrayView1<'_, f64>,
1262 k: usize,
1263 ) -> SpatialResult<Vec<(usize, f64)>> {
1264 let mut all_neighbors = Vec::new();
1265
1266 for node_arc in &self.nodes {
1268 let node = node_arc.read().await;
1269 if let Some(ref local_index) = node.local_index {
1270 if let Some(ref kdtree) = local_index.local_index.kdtree {
1271 if local_index.local_index.bounds.contains(query_point) {
1273 let (indices, distances) =
1274 kdtree.query(query_point.as_slice().unwrap(), k)?;
1275
1276 for (idx, dist) in indices.iter().zip(distances.iter()) {
1277 all_neighbors.push((*idx, *dist));
1278 }
1279 }
1280 }
1281 }
1282 }
1283
1284 all_neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
1286 all_neighbors.truncate(k);
1287
1288 Ok(all_neighbors)
1289 }
1290
1291 pub async fn get_cluster_statistics(&self) -> SpatialResult<ClusterStatistics> {
1293 let state = self.cluster_state.read().await;
1294 let _load_balancer = self.load_balancer.read().await;
1295 let communication = self.communication.read().await;
1296
1297 let active_node_count = state.active_nodes.len();
1298 let total_partitions = state.total_partitions;
1299 let avg_partitions_per_node = if active_node_count > 0 {
1300 total_partitions as f64 / active_node_count as f64
1301 } else {
1302 0.0
1303 };
1304
1305 Ok(ClusterStatistics {
1306 active_nodes: active_node_count,
1307 total_data_points: state.total_data_points,
1308 total_partitions,
1309 avg_partitions_per_node,
1310 health_score: state.health_score,
1311 load_balance_score: state.performance_metrics.load_balance_score,
1312 avg_query_latency_ms: state.performance_metrics.avg_query_latency_ms,
1313 throughput_qps: state.performance_metrics.throughput_qps,
1314 total_messages_sent: communication.stats.messages_sent,
1315 total_bytes_sent: communication.stats.bytes_sent,
1316 avg_communication_latency_ms: communication.stats.average_latency_ms,
1317 })
1318 }
1319}
1320
/// Snapshot returned by `DistributedSpatialCluster::get_cluster_statistics`.
#[derive(Debug, Clone)]
pub struct ClusterStatistics {
    /// Number of entries in the cluster's active-node list.
    pub active_nodes: usize,
    /// Number of data points recorded in the cluster state.
    pub total_data_points: usize,
    /// Number of partitions recorded in the cluster state.
    pub total_partitions: usize,
    /// `total_partitions / active_nodes`, or 0.0 without active nodes.
    pub avg_partitions_per_node: f64,
    /// Cluster health score, copied from the cluster state.
    pub health_score: f64,
    /// Load-balance score, copied from the performance metrics.
    pub load_balance_score: f64,
    /// Average query latency in milliseconds.
    pub avg_query_latency_ms: f64,
    /// Query throughput in queries per second.
    pub throughput_qps: f64,
    /// Messages sent, copied from the communication counters.
    pub total_messages_sent: u64,
    /// Bytes sent, copied from the communication counters.
    pub total_bytes_sent: u64,
    /// Average communication latency in milliseconds.
    pub avg_communication_latency_ms: f64,
}
1336
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builder calls set their flags; fault tolerance lifts replication to 2.
    #[test]
    fn test_nodeconfig() {
        let NodeConfig {
            node_count,
            fault_tolerance,
            load_balancing,
            replication_factor,
            ..
        } = NodeConfig::new()
            .with_node_count(4)
            .with_fault_tolerance(true)
            .with_load_balancing(true);

        assert_eq!(node_count, 4);
        assert!(fault_tolerance);
        assert!(load_balancing);
        assert_eq!(replication_factor, 2);
    }

    /// Containment and volume of the unit square.
    #[test]
    fn test_spatial_bounds() {
        let unit_square = SpatialBounds {
            min_coords: array![0.0, 0.0],
            max_coords: array![1.0, 1.0],
        };
        let inside = array![0.5, 0.5];
        let outside = array![1.5, 0.5];

        assert!(unit_square.contains(&inside.view()));
        assert!(!unit_square.contains(&outside.view()));
        assert_eq!(unit_square.volume(), 1.0);
    }

    /// A moderately loaded node scores strictly between 0 and 1.
    #[test]
    fn test_load_metrics() {
        let load_score = LoadMetrics {
            cpu_utilization: 0.5,
            memory_utilization: 0.3,
            network_utilization: 0.2,
            partition_count: 2,
            operation_count: 100,
            last_update: Instant::now(),
        }
        .load_score();

        assert!(load_score > 0.0 && load_score < 1.0);
    }

    /// A two-node cluster is created with node 0 as master.
    #[tokio::test]
    async fn test_distributed_cluster_creation() {
        let config = NodeConfig::new()
            .with_node_count(2)
            .with_fault_tolerance(false);

        let cluster =
            DistributedSpatialCluster::new(config).expect("cluster creation should succeed");

        assert_eq!(cluster.nodes.len(), 2);
        assert_eq!(cluster.master_node_id, 0);
    }

    /// Distributing four points records them in the cluster statistics.
    #[tokio::test]
    async fn test_data_distribution() {
        let config = NodeConfig::new()
            .with_node_count(2)
            .with_fault_tolerance(false);
        let mut cluster = DistributedSpatialCluster::new(config).unwrap();

        let data = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        cluster
            .distribute_data(&data.view())
            .await
            .expect("distribution should succeed");

        let stats = cluster.get_cluster_statistics().await.unwrap();
        assert_eq!(stats.total_data_points, 4);
        assert!(stats.total_partitions > 0);
    }

    /// k-means returns `k` centroids and one label per input point.
    #[tokio::test]
    async fn test_distributed_kmeans() {
        let mut cluster =
            DistributedSpatialCluster::new(NodeConfig::new().with_node_count(2)).unwrap();

        let data = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [10.0, 10.0],
            [11.0, 10.0]
        ];
        cluster.distribute_data(&data.view()).await.unwrap();

        let (centroids, assignments) = cluster
            .distributed_kmeans(2, 10)
            .await
            .expect("k-means should succeed");

        assert_eq!(centroids.nrows(), 2);
        assert_eq!(assignments.len(), 6);
    }
}