scirs2_cluster/vq/
weighted_kmeans.rs

1//! Weighted K-means clustering implementation
2//!
3//! This module provides K-means clustering with support for weighted samples,
4//! where each data point can have a different importance weight.
5
6use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use scirs2_core::random::Rng;
9use std::fmt::Debug;
10
11use super::{euclidean_distance, kmeans_init, KMeansInit};
12use crate::error::{ClusteringError, Result};
13
14/// Options for weighted K-means clustering
15#[derive(Debug, Clone)]
16pub struct WeightedKMeansOptions<F: Float> {
17    /// Maximum number of iterations
18    pub max_iter: usize,
19    /// Convergence threshold for centroid movement
20    pub tol: F,
21    /// Random seed for initialization
22    pub random_seed: Option<u64>,
23    /// Number of different initializations to try
24    pub n_init: usize,
25    /// Method to use for centroid initialization
26    pub init_method: KMeansInit,
27}
28
29impl<F: Float + FromPrimitive> Default for WeightedKMeansOptions<F> {
30    fn default() -> Self {
31        Self {
32            max_iter: 300,
33            tol: F::from(1e-4).unwrap(),
34            random_seed: None,
35            n_init: 10,
36            init_method: KMeansInit::KMeansPlusPlus,
37        }
38    }
39}
40
41/// Weighted K-means clustering algorithm
42///
43/// This algorithm allows each data point to have a different weight,
44/// which affects both the centroid calculation and the overall objective function.
45///
46/// # Arguments
47///
48/// * `data` - Input data (n_samples × n_features)
49/// * `weights` - Sample weights (n_samples,). Higher weights mean more important samples
50/// * `k` - Number of clusters
51/// * `options` - Optional parameters
52///
53/// # Returns
54///
55/// * Tuple of (centroids, labels) where:
56///   - centroids: Array of shape (k × n_features)
57///   - labels: Array of shape (n_samples,) with cluster assignments
58///
59/// # Examples
60///
61/// ```
62/// use scirs2_core::ndarray::{ArrayView1, Array1, Array2};
63/// use scirs2_cluster::vq::weighted_kmeans;
64///
65/// let data = Array2::from_shape_vec((6, 2), vec![
66///     1.0, 2.0,
67///     1.2, 1.8,
68///     0.8, 1.9,
69///     3.7, 4.2,
70///     3.9, 3.9,
71///     4.2, 4.1,
72/// ]).unwrap();
73///
74/// // Give higher weight to the first three points
75/// let weights = Array1::from_vec(vec![2.0, 2.0, 2.0, 1.0, 1.0, 1.0]);
76///
77/// let (centroids, labels) = weighted_kmeans(data.view(), weights.view(), 2, None).unwrap();
78/// ```
79#[allow(dead_code)]
80pub fn weighted_kmeans<F>(
81    data: ArrayView2<F>,
82    weights: ArrayView1<F>,
83    k: usize,
84    options: Option<WeightedKMeansOptions<F>>,
85) -> Result<(Array2<F>, Array1<usize>)>
86where
87    F: Float + FromPrimitive + Debug + std::iter::Sum,
88{
89    if k == 0 {
90        return Err(ClusteringError::InvalidInput(
91            "Number of clusters must be greater than 0".to_string(),
92        ));
93    }
94
95    let n_samples = data.shape()[0];
96    if n_samples == 0 {
97        return Err(ClusteringError::InvalidInput(
98            "Input data is empty".to_string(),
99        ));
100    }
101
102    if weights.len() != n_samples {
103        return Err(ClusteringError::InvalidInput(
104            "Weights array must have the same length as the number of samples".to_string(),
105        ));
106    }
107
108    if k > n_samples {
109        return Err(ClusteringError::InvalidInput(format!(
110            "Number of clusters ({}) cannot be greater than number of data points ({})",
111            k, n_samples
112        )));
113    }
114
115    // Check that all weights are non-negative
116    for &weight in weights.iter() {
117        if weight < F::zero() {
118            return Err(ClusteringError::InvalidInput(
119                "All weights must be non-negative".to_string(),
120            ));
121        }
122    }
123
124    let opts = options.unwrap_or_default();
125
126    let mut bestcentroids = None;
127    let mut best_labels = None;
128    let mut best_inertia = F::infinity();
129
130    for _ in 0..opts.n_init {
131        // Initialize centroids using the specified method
132        let centroids = kmeans_init(data, k, Some(opts.init_method), opts.random_seed)?;
133
134        // Run weighted k-means
135        let (centroids, labels, inertia) =
136            weighted_kmeans_single(data, weights, centroids.view(), &opts)?;
137
138        if inertia < best_inertia {
139            bestcentroids = Some(centroids);
140            best_labels = Some(labels);
141            best_inertia = inertia;
142        }
143    }
144
145    Ok((bestcentroids.unwrap(), best_labels.unwrap()))
146}
147
148/// Run a single weighted k-means clustering iteration
149#[allow(dead_code)]
150fn weighted_kmeans_single<F>(
151    data: ArrayView2<F>,
152    weights: ArrayView1<F>,
153    initcentroids: ArrayView2<F>,
154    opts: &WeightedKMeansOptions<F>,
155) -> Result<(Array2<F>, Array1<usize>, F)>
156where
157    F: Float + FromPrimitive + Debug + std::iter::Sum,
158{
159    let n_samples = data.shape()[0];
160    let n_features = data.shape()[1];
161    let k = initcentroids.shape()[0];
162
163    let mut centroids = initcentroids.to_owned();
164    let mut labels = Array1::zeros(n_samples);
165    let mut prev_centroid_diff = F::infinity();
166
167    for _iter in 0..opts.max_iter {
168        // Assign samples to nearest centroid
169        let (new_labels, distances) = weighted_assign_labels(data, centroids.view())?;
170        labels = new_labels;
171
172        // Compute new centroids using weights
173        let mut newcentroids = Array2::zeros((k, n_features));
174        let mut total_weights = Array1::zeros(k);
175
176        for i in 0..n_samples {
177            let cluster = labels[i];
178            let point = data.slice(s![i, ..]);
179            let weight = weights[i];
180
181            for j in 0..n_features {
182                newcentroids[[cluster, j]] = newcentroids[[cluster, j]] + point[j] * weight;
183            }
184
185            total_weights[cluster] = total_weights[cluster] + weight;
186        }
187
188        // If a cluster is empty or has very low total weight, reinitialize it
189        for i in 0..k {
190            if total_weights[i] <= F::epsilon() {
191                // Find the point with highest weight * distance to its centroid
192                let mut max_score = F::zero();
193                let mut far_idx = 0;
194
195                for j in 0..n_samples {
196                    let score = weights[j] * distances[j];
197                    if score > max_score {
198                        max_score = score;
199                        far_idx = j;
200                    }
201                }
202
203                // Move this point to the empty cluster
204                for j in 0..n_features {
205                    newcentroids[[i, j]] = data[[far_idx, j]];
206                }
207
208                total_weights[i] = weights[far_idx];
209            } else {
210                // Normalize by the total weight in the cluster
211                for j in 0..n_features {
212                    newcentroids[[i, j]] = newcentroids[[i, j]] / total_weights[i];
213                }
214            }
215        }
216
217        // Check for convergence
218        let mut centroid_diff = F::zero();
219        for i in 0..k {
220            let dist =
221                euclidean_distance(centroids.slice(s![i, ..]), newcentroids.slice(s![i, ..]));
222            centroid_diff = centroid_diff + dist;
223        }
224
225        centroids = newcentroids;
226
227        if centroid_diff <= opts.tol || centroid_diff >= prev_centroid_diff {
228            break;
229        }
230
231        prev_centroid_diff = centroid_diff;
232    }
233
234    // Calculate weighted inertia (sum of weighted squared distances to nearest centroid)
235    let mut inertia = F::zero();
236    for i in 0..n_samples {
237        let cluster = labels[i];
238        let dist = euclidean_distance(data.slice(s![i, ..]), centroids.slice(s![cluster, ..]));
239        inertia = inertia + weights[i] * dist * dist;
240    }
241
242    Ok((centroids, labels, inertia))
243}
244
245/// Assign samples to nearest centroids (same as regular assignment)
246#[allow(dead_code)]
247fn weighted_assign_labels<F>(
248    data: ArrayView2<F>,
249    centroids: ArrayView2<F>,
250) -> Result<(Array1<usize>, Array1<F>)>
251where
252    F: Float + FromPrimitive + Debug,
253{
254    let n_samples = data.shape()[0];
255    let k = centroids.shape()[0];
256
257    let mut labels = Array1::zeros(n_samples);
258    let mut distances = Array1::zeros(n_samples);
259
260    for i in 0..n_samples {
261        let point = data.slice(s![i, ..]);
262        let mut min_dist = F::infinity();
263        let mut closest_centroid = 0;
264
265        for j in 0..k {
266            let centroid = centroids.slice(s![j, ..]);
267            let dist = euclidean_distance(point, centroid);
268
269            if dist < min_dist {
270                min_dist = dist;
271                closest_centroid = j;
272            }
273        }
274
275        labels[i] = closest_centroid;
276        distances[i] = min_dist;
277    }
278
279    Ok((labels, distances))
280}
281
282/// Weighted K-means++ initialization
283///
284/// This uses the weighted version of k-means++ where the probability of selecting
285/// a point as a centroid is proportional to its weight times its squared distance
286/// to the nearest existing centroid.
287///
288/// # Arguments
289///
290/// * `data` - Input data (n_samples × n_features)
291/// * `weights` - Sample weights (n_samples,)
292/// * `k` - Number of clusters
293/// * `random_seed` - Optional random seed
294///
295/// # Returns
296///
297/// * Array of shape (k × n_features) with initial centroids
298#[allow(dead_code)]
299pub fn weighted_kmeans_plus_plus<F>(
300    data: ArrayView2<F>,
301    weights: ArrayView1<F>,
302    k: usize,
303    _random_seed: Option<u64>,
304) -> Result<Array2<F>>
305where
306    F: Float + FromPrimitive + Debug + std::iter::Sum,
307{
308    let n_samples = data.shape()[0];
309    let n_features = data.shape()[1];
310
311    if k == 0 || k > n_samples {
312        return Err(ClusteringError::InvalidInput(format!(
313            "Number of clusters ({}) must be between 1 and number of samples ({})",
314            k, n_samples
315        )));
316    }
317
318    if weights.len() != n_samples {
319        return Err(ClusteringError::InvalidInput(
320            "Weights array must have the same length as the number of samples".to_string(),
321        ));
322    }
323
324    let mut rng = scirs2_core::random::rng();
325
326    let mut centroids = Array2::zeros((k, n_features));
327
328    // Choose the first centroid randomly with probability proportional to weights
329    let total_weight: F = weights.iter().copied().sum();
330    let mut cumulative_weights = Array1::zeros(n_samples);
331    cumulative_weights[0] = weights[0] / total_weight;
332    for i in 1..n_samples {
333        cumulative_weights[i] = cumulative_weights[i - 1] + weights[i] / total_weight;
334    }
335
336    let rand_val = F::from(rng.random::<f64>()).unwrap();
337    let mut first_idx = 0;
338    for i in 0..n_samples {
339        if rand_val <= cumulative_weights[i] {
340            first_idx = i;
341            break;
342        }
343    }
344
345    for j in 0..n_features {
346        centroids[[0, j]] = data[[first_idx, j]];
347    }
348
349    if k == 1 {
350        return Ok(centroids);
351    }
352
353    // Choose remaining centroids using weighted k-means++ algorithm
354    for i in 1..k {
355        // Compute weighted squared distances to closest centroid for each point
356        let mut weighted_distances = Array1::from_elem(n_samples, F::zero());
357
358        for sample_idx in 0..n_samples {
359            let sample = data.slice(s![sample_idx, ..]);
360            let mut min_dist_sq = F::infinity();
361
362            for centroid_idx in 0..i {
363                let centroid = centroids.slice(s![centroid_idx, ..]);
364                let dist = euclidean_distance(sample, centroid);
365                let dist_sq = dist * dist;
366
367                if dist_sq < min_dist_sq {
368                    min_dist_sq = dist_sq;
369                }
370            }
371
372            weighted_distances[sample_idx] = weights[sample_idx] * min_dist_sq;
373        }
374
375        // Normalize the weighted distances to create a probability distribution
376        let sum_weighted_distances: F = weighted_distances.iter().copied().sum();
377        if sum_weighted_distances <= F::epsilon() {
378            // If all weighted distances are zero, use uniform distribution among remaining points
379            let remaining_weight: F = weights.iter().copied().sum();
380            for sample_idx in 0..n_samples {
381                weighted_distances[sample_idx] = weights[sample_idx] / remaining_weight;
382            }
383        } else {
384            weighted_distances.mapv_inplace(|d| d / sum_weighted_distances);
385        }
386
387        // Convert to cumulative distribution
388        let mut cum_weighted_distances = weighted_distances.clone();
389        for j in 1..n_samples {
390            cum_weighted_distances[j] = cum_weighted_distances[j] + cum_weighted_distances[j - 1];
391        }
392
393        // Sample the next centroid based on the weighted probability distribution
394        let rand_val = F::from(rng.random::<f64>()).unwrap();
395        let mut next_idx = 0;
396
397        for j in 0..n_samples {
398            if rand_val <= cum_weighted_distances[j] {
399                next_idx = j;
400                break;
401            }
402        }
403
404        // Add the new centroid
405        for j in 0..n_features {
406            centroids[[i, j]] = data[[next_idx, j]];
407        }
408    }
409
410    Ok(centroids)
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array1, Array2};
    use std::collections::HashSet;

    /// Six 2-D points: three near (1, 2) followed by three near (4, 5).
    fn six_point_data() -> Array2<f64> {
        let flat = vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1];
        Array2::from_shape_vec((6, 2), flat).unwrap()
    }

    /// Four 2-D points: two near (1, 2) followed by two near (4, 5).
    fn four_point_data() -> Array2<f64> {
        let flat = vec![1.0, 2.0, 1.2, 1.8, 4.0, 5.0, 4.2, 4.8];
        Array2::from_shape_vec((4, 2), flat).unwrap()
    }

    /// Default options restricted to a single initialization with seed 42.
    fn single_seeded_run() -> WeightedKMeansOptions<f64> {
        WeightedKMeansOptions {
            n_init: 1,
            random_seed: Some(42),
            ..Default::default()
        }
    }

    #[test]
    fn test_weighted_kmeans_simple() {
        let data = six_point_data();

        // With equal weights the result should match plain k-means.
        let weights = Array1::from_elem(6, 1.0);

        let (centroids, labels) =
            weighted_kmeans(data.view(), weights.view(), 2, Some(single_seeded_run())).unwrap();

        // Output shapes
        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.len(), 6);

        // Both clusters must actually be used.
        let distinct: HashSet<usize> = labels.iter().copied().collect();
        assert_eq!(distinct.len(), 2);
    }

    #[test]
    fn test_weighted_kmeans_different_weights() {
        let data = six_point_data();

        // The first three points dominate through their weight.
        let weights = Array1::from_vec(vec![10.0, 10.0, 10.0, 1.0, 1.0, 1.0]);

        let (centroids, labels) =
            weighted_kmeans(data.view(), weights.view(), 2, Some(single_seeded_run())).unwrap();

        // Output shapes
        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.len(), 6);

        // Row of the centroid that owns the first (heavily weighted) point;
        // with two clusters this is the label itself.
        let heavy_cluster = if labels[0] == 0 { 0 } else { 1 };

        // Weighted mean of the first three points.
        let expected_x = (1.0 * 10.0 + 1.2 * 10.0 + 0.8 * 10.0) / (10.0 + 10.0 + 10.0);
        let expected_y = (2.0 * 10.0 + 1.8 * 10.0 + 1.9 * 10.0) / (10.0 + 10.0 + 10.0);

        assert_abs_diff_eq!(centroids[[heavy_cluster, 0]], expected_x, epsilon = 0.2);
        assert_abs_diff_eq!(centroids[[heavy_cluster, 1]], expected_y, epsilon = 0.2);
    }

    #[test]
    fn test_weighted_kmeans_plus_plus() {
        let data = six_point_data();
        let weights = Array1::from_vec(vec![1.0, 1.0, 1.0, 10.0, 10.0, 10.0]);

        let centroids =
            weighted_kmeans_plus_plus(data.view(), weights.view(), 2, Some(42)).unwrap();

        // Output shape
        assert_eq!(centroids.shape(), &[2, 2]);

        // Every coordinate must be a finite number.
        assert!(centroids.iter().all(|value| value.is_finite()));
    }

    #[test]
    fn test_weighted_kmeans_zero_weights() {
        let data = four_point_data();

        // A mix of zero and positive weights must be accepted.
        let weights = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0]);

        let result = weighted_kmeans(data.view(), weights.view(), 2, Some(single_seeded_run()));
        assert!(result.is_ok());

        let (centroids, labels) = result.unwrap();
        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.len(), 4);
    }

    #[test]
    fn test_weighted_kmeans_negative_weights() {
        let data = four_point_data();

        // A negative weight must be rejected.
        let weights = Array1::from_vec(vec![1.0, -1.0, 1.0, 1.0]);

        let result = weighted_kmeans(data.view(), weights.view(), 2, None);
        assert!(result.is_err());
    }
}