// scirs2_cluster/distributed/partitioning.rs

//! Data partitioning strategies for distributed clustering
//!
//! This module provides various strategies for partitioning data across
//! multiple worker nodes in distributed clustering algorithms.
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive, Zero};
use scirs2_core::random::seq::SliceRandom;
use std::fmt::Debug;

use super::fault_tolerance::DataPartition;
use crate::error::{ClusteringError, Result};
use crate::vq::euclidean_distance;
14
/// Data partitioning coordinator for distributed clustering.
///
/// Holds the partitioning configuration, the partitions produced by the
/// most recent `partition_data` call, and statistics describing how well
/// the data was split across workers.
#[derive(Debug)]
pub struct DataPartitioner<F: Float> {
    /// Strategy and sizing configuration used when partitioning.
    pub config: PartitioningConfig,
    /// Partitions produced by the last call to `partition_data`.
    pub partitions: Vec<DataPartition<F>>,
    /// Sizes, balance score, timing and memory stats from the last run.
    pub partition_stats: PartitioningStatistics,
}
22
/// Configuration for data partitioning.
#[derive(Debug, Clone)]
pub struct PartitioningConfig {
    /// Number of worker nodes to split the data across (must be > 0).
    pub n_workers: usize,
    /// Strategy used to assign data points to workers.
    pub strategy: PartitioningStrategy,
    /// Allowed deviation from perfect balance; partitions are considered
    /// balanced when the load-balance score is >= 1.0 - balance_threshold.
    pub balance_threshold: f64,
    /// When true, `rebalance_if_needed` may re-partition imbalanced data.
    pub enable_load_balancing: bool,
    /// Lower bound on the target size of each partition.
    pub min_partition_size: usize,
    /// Optional upper bound on the target size of each partition.
    pub max_partition_size: Option<usize>,
    /// Hint that spatially close points should stay together.
    /// NOTE(review): not read by any current strategy — confirm intended use.
    pub preserve_locality: bool,
    /// Seed for reproducible partitioning.
    /// NOTE(review): currently unused; `random_partition` draws a fresh,
    /// unseeded RNG — confirm whether the seed should be threaded through.
    pub random_seed: Option<u64>,
}
35
36impl Default for PartitioningConfig {
37    fn default() -> Self {
38        Self {
39            n_workers: 4,
40            strategy: PartitioningStrategy::Random,
41            balance_threshold: 0.1,
42            enable_load_balancing: true,
43            min_partition_size: 100,
44            max_partition_size: None,
45            preserve_locality: false,
46            random_seed: None,
47        }
48    }
49}
50
/// Available partitioning strategies.
#[derive(Debug, Clone)]
pub enum PartitioningStrategy {
    /// Random partitioning: shuffle all rows, then split the shuffled
    /// order into contiguous chunks.
    Random,
    /// Round-robin partitioning: row `i` goes to worker `i % n_workers`.
    RoundRobin,
    /// Stratified partitioning based on preliminary clustering; each
    /// worker receives a proportional share of every stratum.
    Stratified { n_strata: usize },
    /// Hash-based partitioning (hashes each row to pick a worker).
    Hash,
    /// Range-based partitioning: rows sorted by the given feature, then
    /// split into consecutive ranges.
    Range { feature_index: usize },
    /// Locality-preserving partitioning: greedily grows partitions from
    /// seed points by pulling in points whose similarity (1/(1+distance))
    /// clears the threshold.
    LocalityPreserving { similarity_threshold: f64 },
    /// Custom partitioning with user-defined function.
    /// NOTE(review): currently delegates to `Random` — see `custom_partition`.
    Custom,
}
69
/// Statistics about the partitioning.
#[derive(Debug, Default)]
pub struct PartitioningStatistics {
    /// Number of rows in each produced partition.
    pub partition_sizes: Vec<usize>,
    /// 1.0 = perfectly balanced, 0.0 = completely imbalanced
    /// (computed as 1 - stddev/mean of partition sizes, clamped at 0).
    pub load_balance_score: f64,
    /// Locality quality score.
    /// NOTE(review): never written by the current code — confirm whether
    /// it should be populated by `LocalityPreserving`.
    pub locality_score: f64,
    /// Wall-clock time of the last partitioning run, in milliseconds.
    pub partitioning_time_ms: u64,
    /// Approximate bytes held by partition data (element count * size of F).
    pub memory_usage_bytes: usize,
}
79
80impl<F: Float + FromPrimitive + Debug + Send + Sync> DataPartitioner<F> {
81    /// Create new data partitioner
82    pub fn new(config: PartitioningConfig) -> Self {
83        Self {
84            config,
85            partitions: Vec::new(),
86            partition_stats: PartitioningStatistics::default(),
87        }
88    }
89
    /// Partition data according to the configured strategy.
    ///
    /// Computes per-worker target sizes, dispatches to the strategy
    /// implementation, records timing/size/balance statistics, caches the
    /// result in `self.partitions`, and returns a clone of the partitions.
    ///
    /// # Errors
    /// Fails when `config.n_workers == 0` or when the selected strategy
    /// fails (e.g. out-of-bounds feature index for `Range`).
    pub fn partition_data(&mut self, data: ArrayView2<F>) -> Result<Vec<DataPartition<F>>> {
        let start_time = std::time::Instant::now();

        // Calculate target partition sizes
        let partition_sizes = self.calculate_partition_sizes(data.nrows())?;

        // Apply partitioning strategy
        let partitions = match &self.config.strategy {
            PartitioningStrategy::Random => self.random_partition(data, &partition_sizes),
            PartitioningStrategy::RoundRobin => self.round_robin_partition(data, &partition_sizes),
            PartitioningStrategy::Stratified { n_strata } => {
                self.stratified_partition(data, &partition_sizes, *n_strata)
            }
            PartitioningStrategy::Hash => self.hash_partition(data, &partition_sizes),
            PartitioningStrategy::Range { feature_index } => {
                self.range_partition(data, &partition_sizes, *feature_index)
            }
            PartitioningStrategy::LocalityPreserving {
                similarity_threshold,
            } => self.locality_preserving_partition(data, &partition_sizes, *similarity_threshold),
            PartitioningStrategy::Custom => self.custom_partition(data, &partition_sizes),
        }?;

        // Update statistics (timing measured over size calc + strategy).
        let partitioning_time = start_time.elapsed().as_millis() as u64;
        self.update_statistics(&partitions, partitioning_time);

        // Cache a copy so `get_partitions` reflects the latest run.
        self.partitions = partitions.clone();
        Ok(partitions)
    }
121
122    /// Calculate target sizes for each partition
123    fn calculate_partition_sizes(&self, totalsize: usize) -> Result<Vec<usize>> {
124        if self.config.n_workers == 0 {
125            return Err(ClusteringError::InvalidInput(
126                "Number of workers must be greater than 0".to_string(),
127            ));
128        }
129
130        let base_size = totalsize / self.config.n_workers;
131        let remainder = totalsize % self.config.n_workers;
132
133        let mut sizes = vec![base_size; self.config.n_workers];
134
135        // Distribute remainder across first few workers
136        for i in 0..remainder {
137            sizes[i] += 1;
138        }
139
140        // Adjust for minimum partition size constraints
141        // If total is less than n_workers * min_partition_size, we can't satisfy the constraint
142        // so we use the calculated sizes instead
143        let effective_min_size = self
144            .config
145            .min_partition_size
146            .min(totalsize / self.config.n_workers + 1);
147
148        for size in &mut sizes {
149            if *size < effective_min_size {
150                *size = effective_min_size;
151            }
152            if let Some(max_size) = self.config.max_partition_size {
153                if *size > max_size {
154                    *size = max_size;
155                }
156            }
157        }
158
159        // Ensure the total doesn't exceed the original totalsize
160        let current_total: usize = sizes.iter().sum();
161        if current_total > totalsize {
162            // Redistribute to match totalsize exactly
163            let mut sizes = vec![totalsize / self.config.n_workers; self.config.n_workers];
164            let remainder = totalsize % self.config.n_workers;
165            for i in 0..remainder {
166                sizes[i] += 1;
167            }
168            return Ok(sizes);
169        }
170
171        Ok(sizes)
172    }
173
174    /// Random partitioning strategy
175    fn random_partition(
176        &self,
177        data: ArrayView2<F>,
178        partition_sizes: &[usize],
179    ) -> Result<Vec<DataPartition<F>>> {
180        let n_samples = data.nrows();
181        let n_workers = self.config.n_workers;
182
183        // Create random permutation of indices
184        let mut indices: Vec<usize> = (0..n_samples).collect();
185        let mut rng = scirs2_core::random::rng();
186        indices.shuffle(&mut rng);
187
188        let mut partitions = Vec::new();
189        let mut start_idx = 0;
190
191        for (worker_id, &partition_size) in partition_sizes.iter().enumerate() {
192            let end_idx = (start_idx + partition_size).min(n_samples);
193
194            if start_idx < end_idx {
195                let mut partition_data = Array2::zeros((end_idx - start_idx, data.ncols()));
196
197                for (i, &data_idx) in indices[start_idx..end_idx].iter().enumerate() {
198                    partition_data.row_mut(i).assign(&data.row(data_idx));
199                }
200
201                partitions.push(DataPartition::new(worker_id, partition_data, worker_id));
202            }
203
204            start_idx = end_idx;
205            if start_idx >= n_samples {
206                break;
207            }
208        }
209
210        Ok(partitions)
211    }
212
    /// Stratified partitioning using preliminary clustering.
    ///
    /// Runs a quick K-means (`identify_strata`) to split the rows into
    /// `n_strata` groups, then deals points from each stratum out to
    /// workers in proportion to each worker's remaining capacity, so every
    /// partition sees a similar mix of strata. Falls back to random
    /// partitioning when there are fewer samples than strata.
    fn stratified_partition(
        &self,
        data: ArrayView2<F>,
        partition_sizes: &[usize],
        n_strata: usize,
    ) -> Result<Vec<DataPartition<F>>> {
        let n_samples = data.nrows();

        if n_samples < n_strata {
            // Fall back to random if not enough data
            return self.random_partition(data, partition_sizes);
        }

        // Step 1: Perform preliminary clustering to identify strata
        let strata_assignments = self.identify_strata(data, n_strata)?;

        // Step 2: Group data points by stratum
        let mut strata_groups: Vec<Vec<usize>> = vec![Vec::new(); n_strata];
        for (point_idx, &stratum_id) in strata_assignments.iter().enumerate() {
            strata_groups[stratum_id].push(point_idx);
        }

        // Step 3: Distribute strata points proportionally to workers
        let mut worker_assignments: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_workers];

        for stratum_points in strata_groups.iter() {
            if stratum_points.is_empty() {
                continue;
            }

            // Calculate how many points each worker should get from this stratum
            let total_points = stratum_points.len();
            let mut distributed = 0;

            for worker_id in 0..self.config.n_workers {
                let target_size = partition_sizes[worker_id];
                let current_size = worker_assignments[worker_id].len();
                let remaining_capacity = target_size.saturating_sub(current_size);

                // Proportional allocation with remaining capacity constraint.
                // Only workers from `worker_id` onward are counted, since
                // earlier workers have already taken their share.
                let total_remaining_capacity: usize = worker_assignments
                    .iter()
                    .enumerate()
                    .skip(worker_id)
                    .map(|(i, assignments)| partition_sizes[i].saturating_sub(assignments.len()))
                    .sum();

                let points_for_worker = if total_remaining_capacity == 0 {
                    0
                } else {
                    let proportion = remaining_capacity as f64 / total_remaining_capacity as f64;
                    let remaining_points = total_points - distributed;
                    ((remaining_points as f64 * proportion).round() as usize)
                        .min(remaining_points)
                        .min(remaining_capacity)
                };

                // Assign the next slice of this stratum to the worker
                let start_idx = distributed;
                let end_idx = (start_idx + points_for_worker).min(total_points);

                for &point_idx in &stratum_points[start_idx..end_idx] {
                    worker_assignments[worker_id].push(point_idx);
                }

                distributed = end_idx;

                if distributed >= total_points {
                    break;
                }
            }
        }

        // Step 4: Create partitions from worker assignments (workers that
        // received no points produce no partition)
        let mut partitions = Vec::new();
        for (worker_id, point_indices) in worker_assignments.into_iter().enumerate() {
            if !point_indices.is_empty() {
                let mut partition_data = Array2::zeros((point_indices.len(), data.ncols()));

                for (i, &point_idx) in point_indices.iter().enumerate() {
                    partition_data.row_mut(i).assign(&data.row(point_idx));
                }

                partitions.push(DataPartition::new(worker_id, partition_data, worker_id));
            }
        }

        Ok(partitions)
    }
303
    /// Identify strata using simple K-means clustering.
    ///
    /// A lightweight Lloyd's-style K-means (at most 10 iterations) used
    /// only to derive strata for `stratified_partition`. Centroids are
    /// seeded from randomly chosen rows; returns one stratum index in
    /// `0..nstrata` per row.
    fn identify_strata(&self, data: ArrayView2<F>, nstrata: usize) -> Result<Array1<usize>> {
        let n_samples = data.nrows();
        let n_features = data.ncols();

        // Initialize centroids from `nstrata` randomly chosen rows
        let mut rng = scirs2_core::random::rng();
        let mut point_indices: Vec<usize> = (0..n_samples).collect();
        point_indices.shuffle(&mut rng);

        let mut centroids = Array2::zeros((nstrata, n_features));
        for (i, &point_idx) in point_indices.iter().take(nstrata).enumerate() {
            centroids.row_mut(i).assign(&data.row(point_idx));
        }

        let mut assignments = Array1::zeros(n_samples);
        let max_iterations = 10; // Quick preliminary clustering

        for _ in 0..max_iterations {
            let mut changed = false;

            // Assignment step: move each point to its nearest centroid
            for (point_idx, point) in data.rows().into_iter().enumerate() {
                let mut min_dist = F::infinity();
                let mut best_centroid = 0;

                for (centroid_idx, centroid) in centroids.rows().into_iter().enumerate() {
                    let dist = euclidean_distance(point, centroid);
                    if dist < min_dist {
                        min_dist = dist;
                        best_centroid = centroid_idx;
                    }
                }

                if assignments[point_idx] != best_centroid {
                    assignments[point_idx] = best_centroid;
                    changed = true;
                }
            }

            // Converged: no assignment changed this iteration
            if !changed {
                break;
            }

            // Update step: accumulate per-cluster sums and counts
            centroids.fill(F::zero());
            let mut counts = vec![0; nstrata];

            for (point_idx, point) in data.rows().into_iter().enumerate() {
                let cluster_id = assignments[point_idx];
                for (j, &value) in point.iter().enumerate() {
                    centroids[[cluster_id, j]] = centroids[[cluster_id, j]] + value;
                }
                counts[cluster_id] += 1;
            }

            // Divide sums by counts; empty clusters keep zero centroids
            for i in 0..nstrata {
                if counts[i] > 0 {
                    for j in 0..n_features {
                        centroids[[i, j]] = centroids[[i, j]] / F::from(counts[i]).unwrap();
                    }
                }
            }
        }

        Ok(assignments)
    }
372
373    /// Round-robin partitioning
374    fn round_robin_partition(
375        &self,
376        data: ArrayView2<F>,
377        _partition_sizes: &[usize],
378    ) -> Result<Vec<DataPartition<F>>> {
379        let n_workers = self.config.n_workers;
380        let mut worker_data: Vec<Vec<usize>> = vec![Vec::new(); n_workers];
381
382        // Assign points in round-robin fashion
383        for (row_idx, _) in data.rows().into_iter().enumerate() {
384            let worker_id = row_idx % n_workers;
385            worker_data[worker_id].push(row_idx);
386        }
387
388        // Create partitions
389        let mut partitions = Vec::new();
390        for (worker_id, row_indices) in worker_data.into_iter().enumerate() {
391            if !row_indices.is_empty() {
392                let mut partition_data = Array2::zeros((row_indices.len(), data.ncols()));
393
394                for (i, &row_idx) in row_indices.iter().enumerate() {
395                    partition_data.row_mut(i).assign(&data.row(row_idx));
396                }
397
398                partitions.push(DataPartition {
399                    partition_id: worker_id,
400                    data: partition_data,
401                    labels: None,
402                    workerid: worker_id,
403                    weight: row_indices.len() as f64 / data.nrows() as f64,
404                });
405            }
406        }
407
408        Ok(partitions)
409    }
410
411    /// Hash-based partitioning using feature hash
412    fn hash_partition(
413        &self,
414        data: ArrayView2<F>,
415        partition_sizes: &[usize],
416    ) -> Result<Vec<DataPartition<F>>> {
417        let n_workers = self.config.n_workers;
418        let mut worker_assignments: Vec<Vec<usize>> = vec![Vec::new(); n_workers];
419
420        // Hash each data point to a worker
421        for (row_idx, row) in data.rows().into_iter().enumerate() {
422            // Simple hash based on first feature (can be improved)
423            let hash_value = if !row.is_empty() {
424                (row[0].to_f64().unwrap_or(0.0) * 1000.0) as u64
425            } else {
426                row_idx as u64
427            };
428
429            let worker_id = (hash_value % n_workers as u64) as usize;
430            worker_assignments[worker_id].push(row_idx);
431        }
432
433        // Create partitions with size balancing
434        let mut partitions = Vec::new();
435        for (worker_id, row_indices) in worker_assignments.into_iter().enumerate() {
436            // Limit partition size if needed
437            let max_size = partition_sizes
438                .get(worker_id)
439                .copied()
440                .unwrap_or(row_indices.len());
441            let actual_indices = if row_indices.len() > max_size {
442                &row_indices[..max_size]
443            } else {
444                &row_indices
445            };
446
447            if !actual_indices.is_empty() {
448                let mut partition_data = Array2::zeros((actual_indices.len(), data.ncols()));
449
450                for (i, &row_idx) in actual_indices.iter().enumerate() {
451                    partition_data.row_mut(i).assign(&data.row(row_idx));
452                }
453
454                partitions.push(DataPartition::new(worker_id, partition_data, worker_id));
455            }
456        }
457
458        Ok(partitions)
459    }
460
461    /// Range-based partitioning on a specific feature
462    fn range_partition(
463        &self,
464        data: ArrayView2<F>,
465        partition_sizes: &[usize],
466        feature_index: usize,
467    ) -> Result<Vec<DataPartition<F>>> {
468        if feature_index >= data.ncols() {
469            return Err(ClusteringError::InvalidInput(
470                "Feature index out of bounds".to_string(),
471            ));
472        }
473
474        // Extract feature values and sort indices
475        let mut indexed_values: Vec<(usize, F)> = data
476            .column(feature_index)
477            .iter()
478            .enumerate()
479            .map(|(i, &val)| (i, val))
480            .collect();
481
482        indexed_values.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
483
484        // Partition based on sorted order
485        let mut partitions = Vec::new();
486        let mut start_idx = 0;
487
488        for (worker_id, &partition_size) in partition_sizes.iter().enumerate() {
489            let end_idx = (start_idx + partition_size).min(indexed_values.len());
490
491            if start_idx < end_idx {
492                let mut partition_data = Array2::zeros((end_idx - start_idx, data.ncols()));
493
494                for (i, &(original_idx, _)) in indexed_values[start_idx..end_idx].iter().enumerate()
495                {
496                    partition_data.row_mut(i).assign(&data.row(original_idx));
497                }
498
499                partitions.push(DataPartition::new(worker_id, partition_data, worker_id));
500            }
501
502            start_idx = end_idx;
503            if start_idx >= indexed_values.len() {
504                break;
505            }
506        }
507
508        Ok(partitions)
509    }
510
511    /// Locality-preserving partitioning based on similarity
512    fn locality_preserving_partition(
513        &self,
514        data: ArrayView2<F>,
515        partition_sizes: &[usize],
516        similarity_threshold: f64,
517    ) -> Result<Vec<DataPartition<F>>> {
518        let n_samples = data.nrows();
519        let mut assigned: Vec<bool> = vec![false; n_samples];
520        let mut worker_assignments: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_workers];
521
522        let mut current_worker = 0;
523        let mut unassigned_points: Vec<usize> = (0..n_samples).collect();
524
525        while !unassigned_points.is_empty() && current_worker < self.config.n_workers {
526            let target_size = partition_sizes[current_worker];
527            let mut current_partition = Vec::new();
528
529            // Start with a random unassigned point
530            if let Some(seed_idx) = unassigned_points.first().copied() {
531                current_partition.push(seed_idx);
532                assigned[seed_idx] = true;
533                unassigned_points.retain(|&x| x != seed_idx);
534
535                // Grow partition by adding similar points
536                while current_partition.len() < target_size && !unassigned_points.is_empty() {
537                    let mut best_similarity = 0.0;
538                    let mut best_candidate = None;
539
540                    // Find most similar unassigned point to any point in current partition
541                    for &candidate_idx in &unassigned_points {
542                        let candidate_point = data.row(candidate_idx);
543
544                        for &partition_point_idx in &current_partition {
545                            let partition_point = data.row(partition_point_idx);
546                            let distance = euclidean_distance(candidate_point, partition_point)
547                                .to_f64()
548                                .unwrap_or(f64::INFINITY);
549                            let similarity = 1.0 / (1.0 + distance); // Convert distance to similarity
550
551                            if similarity > best_similarity && similarity >= similarity_threshold {
552                                best_similarity = similarity;
553                                best_candidate = Some(candidate_idx);
554                            }
555                        }
556                    }
557
558                    if let Some(best_idx) = best_candidate {
559                        current_partition.push(best_idx);
560                        assigned[best_idx] = true;
561                        unassigned_points.retain(|&x| x != best_idx);
562                    } else {
563                        // No similar points found, add random points to fill partition
564                        while current_partition.len() < target_size && !unassigned_points.is_empty()
565                        {
566                            let random_idx = unassigned_points.remove(0);
567                            current_partition.push(random_idx);
568                            assigned[random_idx] = true;
569                        }
570                        break;
571                    }
572                }
573
574                worker_assignments[current_worker] = current_partition;
575            }
576
577            current_worker += 1;
578        }
579
580        // Assign any remaining points to workers with space
581        for remaining_idx in unassigned_points {
582            for worker_id in 0..self.config.n_workers {
583                if worker_assignments[worker_id].len() < partition_sizes[worker_id] {
584                    worker_assignments[worker_id].push(remaining_idx);
585                    break;
586                }
587            }
588        }
589
590        // Create partitions
591        let mut partitions = Vec::new();
592        for (worker_id, point_indices) in worker_assignments.into_iter().enumerate() {
593            if !point_indices.is_empty() {
594                let mut partition_data = Array2::zeros((point_indices.len(), data.ncols()));
595
596                for (i, &point_idx) in point_indices.iter().enumerate() {
597                    partition_data.row_mut(i).assign(&data.row(point_idx));
598                }
599
600                partitions.push(DataPartition::new(worker_id, partition_data, worker_id));
601            }
602        }
603
604        Ok(partitions)
605    }
606
    /// Custom partitioning (placeholder for user-defined strategies).
    ///
    /// A full implementation would accept a user-supplied partitioning
    /// function; until then this simply delegates to the random strategy.
    fn custom_partition(
        &self,
        data: ArrayView2<F>,
        partition_sizes: &[usize],
    ) -> Result<Vec<DataPartition<F>>> {
        // Default to random partitioning for custom strategy
        self.random_partition(data, partition_sizes)
    }
617
618    /// Update partitioning statistics
619    fn update_statistics(&mut self, partitions: &[DataPartition<F>], partitioning_timems: u64) {
620        self.partition_stats.partition_sizes = partitions.iter().map(|p| p.data.nrows()).collect();
621        self.partition_stats.partitioning_time_ms = partitioning_timems;
622
623        // Calculate load balance score (1.0 = perfectly balanced, 0.0 = completely imbalanced)
624        if !self.partition_stats.partition_sizes.is_empty() {
625            let avg_size = self.partition_stats.partition_sizes.iter().sum::<usize>() as f64
626                / self.partition_stats.partition_sizes.len() as f64;
627            let variance = self
628                .partition_stats
629                .partition_sizes
630                .iter()
631                .map(|&size| (size as f64 - avg_size).powi(2))
632                .sum::<f64>()
633                / self.partition_stats.partition_sizes.len() as f64;
634
635            self.partition_stats.load_balance_score = if avg_size > 0.0 {
636                1.0 - (variance.sqrt() / avg_size).min(1.0)
637            } else {
638                0.0
639            };
640        }
641
642        // Calculate memory usage (approximate)
643        self.partition_stats.memory_usage_bytes = partitions
644            .iter()
645            .map(|p| p.data.len() * std::mem::size_of::<F>())
646            .sum();
647    }
648
    /// Get partitioning statistics.
    ///
    /// Returns the statistics recorded by the most recent
    /// `partition_data` run (zeroed defaults before any run).
    pub fn get_statistics(&self) -> &PartitioningStatistics {
        &self.partition_stats
    }
653
    /// Get current partitions.
    ///
    /// Returns the partitions cached by the most recent `partition_data`
    /// run (empty before any run).
    pub fn get_partitions(&self) -> &[DataPartition<F>] {
        &self.partitions
    }
658
659    /// Validate partition balance
660    pub fn validate_partition_balance(&self) -> bool {
661        self.partition_stats.load_balance_score >= (1.0 - self.config.balance_threshold)
662    }
663
664    /// Rebalance partitions if needed
665    pub fn rebalance_if_needed(&mut self, data: ArrayView2<F>) -> Result<bool> {
666        if !self.config.enable_load_balancing || self.validate_partition_balance() {
667            return Ok(false);
668        }
669
670        // Re-partition the data
671        self.partition_data(data)?;
672        Ok(true)
673    }
674}
675
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// A freshly constructed partitioner carries its config and no partitions.
    #[test]
    fn test_data_partitioner_creation() {
        let config = PartitioningConfig::default();
        let partitioner = DataPartitioner::<f64>::new(config);

        assert_eq!(partitioner.config.n_workers, 4);
        assert!(partitioner.partitions.is_empty());
    }

    /// Sizes sum to the total and differ by at most one row.
    #[test]
    fn test_calculate_partition_sizes() {
        let config = PartitioningConfig {
            n_workers: 3,
            min_partition_size: 1, // Set a reasonable min_partition_size for the test
            ..Default::default()
        };
        let partitioner = DataPartitioner::<f64>::new(config);

        let sizes = partitioner.calculate_partition_sizes(100).unwrap();
        assert_eq!(sizes.len(), 3);
        assert_eq!(sizes.iter().sum::<usize>(), 100);

        // Should be approximately balanced
        let max_diff = sizes.iter().max().unwrap() - sizes.iter().min().unwrap();
        assert!(max_diff <= 1);
    }

    /// Random partitioning covers every row exactly once across workers.
    #[test]
    fn test_random_partitioning() {
        let config = PartitioningConfig {
            n_workers: 2,
            strategy: PartitioningStrategy::Random,
            min_partition_size: 1, // Set a reasonable min_partition_size for the test
            ..Default::default()
        };
        let mut partitioner = DataPartitioner::new(config);

        let data = Array2::from_shape_vec((100, 3), (0..300).map(|x| x as f64).collect()).unwrap();
        let partitions = partitioner.partition_data(data.view()).unwrap();

        assert_eq!(partitions.len(), 2);
        assert!(partitions.iter().all(|p| p.data.nrows() > 0));

        let total_points: usize = partitions.iter().map(|p| p.data.nrows()).sum();
        assert_eq!(total_points, 100);
    }

    /// Round-robin with 99 rows over 3 workers yields 33 rows each.
    #[test]
    fn test_round_robin_partitioning() {
        let config = PartitioningConfig {
            n_workers: 3,
            strategy: PartitioningStrategy::RoundRobin,
            ..Default::default()
        };
        let mut partitioner = DataPartitioner::new(config);

        let data = Array2::from_shape_vec((99, 2), (0..198).map(|x| x as f64).collect()).unwrap();
        let partitions = partitioner.partition_data(data.view()).unwrap();

        assert_eq!(partitions.len(), 3);
        assert_eq!(partitions[0].data.nrows(), 33);
        assert_eq!(partitions[1].data.nrows(), 33);
        assert_eq!(partitions[2].data.nrows(), 33);
    }

    /// Load-balance score is ~1.0 for equal partitions and low when skewed.
    #[test]
    fn test_load_balance_score() {
        let config = PartitioningConfig::default();
        let mut partitioner = DataPartitioner::<f64>::new(config);

        // Perfect balance - create mock partitions
        let balanced_partitions: Vec<DataPartition<f64>> = (0..4)
            .map(|i| DataPartition::new(i, Array2::zeros((25, 2)), i))
            .collect();
        partitioner.update_statistics(&balanced_partitions, 0);
        assert!((partitioner.partition_stats.load_balance_score - 1.0).abs() < 0.01);

        // Imbalanced - create imbalanced mock partitions
        let imbalanced_partitions = vec![
            DataPartition::new(0, Array2::zeros((10, 2)), 0),
            DataPartition::new(1, Array2::zeros((90, 2)), 1),
        ];
        partitioner.update_statistics(&imbalanced_partitions, 0);
        assert!(partitioner.partition_stats.load_balance_score < 0.5);
    }

    /// Min/max size constraints are honored when they are satisfiable.
    #[test]
    fn test_partition_size_constraints() {
        let config = PartitioningConfig {
            n_workers: 3,
            min_partition_size: 10,
            max_partition_size: Some(50),
            ..Default::default()
        };
        let partitioner = DataPartitioner::<f64>::new(config);

        let sizes = partitioner.calculate_partition_sizes(120).unwrap();
        assert!(sizes.iter().all(|&size| size >= 10 && size <= 50));
    }
}