1use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use scirs2_core::random::Rng;
9use std::fmt::Debug;
10
11use super::{euclidean_distance, kmeans_init, KMeansInit};
12use crate::error::{ClusteringError, Result};
13
14#[derive(Debug, Clone)]
16pub struct WeightedKMeansOptions<F: Float> {
17 pub max_iter: usize,
19 pub tol: F,
21 pub random_seed: Option<u64>,
23 pub n_init: usize,
25 pub init_method: KMeansInit,
27}
28
29impl<F: Float + FromPrimitive> Default for WeightedKMeansOptions<F> {
30 fn default() -> Self {
31 Self {
32 max_iter: 300,
33 tol: F::from(1e-4).unwrap(),
34 random_seed: None,
35 n_init: 10,
36 init_method: KMeansInit::KMeansPlusPlus,
37 }
38 }
39}
40
41#[allow(dead_code)]
80pub fn weighted_kmeans<F>(
81 data: ArrayView2<F>,
82 weights: ArrayView1<F>,
83 k: usize,
84 options: Option<WeightedKMeansOptions<F>>,
85) -> Result<(Array2<F>, Array1<usize>)>
86where
87 F: Float + FromPrimitive + Debug + std::iter::Sum,
88{
89 if k == 0 {
90 return Err(ClusteringError::InvalidInput(
91 "Number of clusters must be greater than 0".to_string(),
92 ));
93 }
94
95 let n_samples = data.shape()[0];
96 if n_samples == 0 {
97 return Err(ClusteringError::InvalidInput(
98 "Input data is empty".to_string(),
99 ));
100 }
101
102 if weights.len() != n_samples {
103 return Err(ClusteringError::InvalidInput(
104 "Weights array must have the same length as the number of samples".to_string(),
105 ));
106 }
107
108 if k > n_samples {
109 return Err(ClusteringError::InvalidInput(format!(
110 "Number of clusters ({}) cannot be greater than number of data points ({})",
111 k, n_samples
112 )));
113 }
114
115 for &weight in weights.iter() {
117 if weight < F::zero() {
118 return Err(ClusteringError::InvalidInput(
119 "All weights must be non-negative".to_string(),
120 ));
121 }
122 }
123
124 let opts = options.unwrap_or_default();
125
126 let mut bestcentroids = None;
127 let mut best_labels = None;
128 let mut best_inertia = F::infinity();
129
130 for _ in 0..opts.n_init {
131 let centroids = kmeans_init(data, k, Some(opts.init_method), opts.random_seed)?;
133
134 let (centroids, labels, inertia) =
136 weighted_kmeans_single(data, weights, centroids.view(), &opts)?;
137
138 if inertia < best_inertia {
139 bestcentroids = Some(centroids);
140 best_labels = Some(labels);
141 best_inertia = inertia;
142 }
143 }
144
145 Ok((bestcentroids.unwrap(), best_labels.unwrap()))
146}
147
148#[allow(dead_code)]
150fn weighted_kmeans_single<F>(
151 data: ArrayView2<F>,
152 weights: ArrayView1<F>,
153 initcentroids: ArrayView2<F>,
154 opts: &WeightedKMeansOptions<F>,
155) -> Result<(Array2<F>, Array1<usize>, F)>
156where
157 F: Float + FromPrimitive + Debug + std::iter::Sum,
158{
159 let n_samples = data.shape()[0];
160 let n_features = data.shape()[1];
161 let k = initcentroids.shape()[0];
162
163 let mut centroids = initcentroids.to_owned();
164 let mut labels = Array1::zeros(n_samples);
165 let mut prev_centroid_diff = F::infinity();
166
167 for _iter in 0..opts.max_iter {
168 let (new_labels, distances) = weighted_assign_labels(data, centroids.view())?;
170 labels = new_labels;
171
172 let mut newcentroids = Array2::zeros((k, n_features));
174 let mut total_weights = Array1::zeros(k);
175
176 for i in 0..n_samples {
177 let cluster = labels[i];
178 let point = data.slice(s![i, ..]);
179 let weight = weights[i];
180
181 for j in 0..n_features {
182 newcentroids[[cluster, j]] = newcentroids[[cluster, j]] + point[j] * weight;
183 }
184
185 total_weights[cluster] = total_weights[cluster] + weight;
186 }
187
188 for i in 0..k {
190 if total_weights[i] <= F::epsilon() {
191 let mut max_score = F::zero();
193 let mut far_idx = 0;
194
195 for j in 0..n_samples {
196 let score = weights[j] * distances[j];
197 if score > max_score {
198 max_score = score;
199 far_idx = j;
200 }
201 }
202
203 for j in 0..n_features {
205 newcentroids[[i, j]] = data[[far_idx, j]];
206 }
207
208 total_weights[i] = weights[far_idx];
209 } else {
210 for j in 0..n_features {
212 newcentroids[[i, j]] = newcentroids[[i, j]] / total_weights[i];
213 }
214 }
215 }
216
217 let mut centroid_diff = F::zero();
219 for i in 0..k {
220 let dist =
221 euclidean_distance(centroids.slice(s![i, ..]), newcentroids.slice(s![i, ..]));
222 centroid_diff = centroid_diff + dist;
223 }
224
225 centroids = newcentroids;
226
227 if centroid_diff <= opts.tol || centroid_diff >= prev_centroid_diff {
228 break;
229 }
230
231 prev_centroid_diff = centroid_diff;
232 }
233
234 let mut inertia = F::zero();
236 for i in 0..n_samples {
237 let cluster = labels[i];
238 let dist = euclidean_distance(data.slice(s![i, ..]), centroids.slice(s![cluster, ..]));
239 inertia = inertia + weights[i] * dist * dist;
240 }
241
242 Ok((centroids, labels, inertia))
243}
244
245#[allow(dead_code)]
247fn weighted_assign_labels<F>(
248 data: ArrayView2<F>,
249 centroids: ArrayView2<F>,
250) -> Result<(Array1<usize>, Array1<F>)>
251where
252 F: Float + FromPrimitive + Debug,
253{
254 let n_samples = data.shape()[0];
255 let k = centroids.shape()[0];
256
257 let mut labels = Array1::zeros(n_samples);
258 let mut distances = Array1::zeros(n_samples);
259
260 for i in 0..n_samples {
261 let point = data.slice(s![i, ..]);
262 let mut min_dist = F::infinity();
263 let mut closest_centroid = 0;
264
265 for j in 0..k {
266 let centroid = centroids.slice(s![j, ..]);
267 let dist = euclidean_distance(point, centroid);
268
269 if dist < min_dist {
270 min_dist = dist;
271 closest_centroid = j;
272 }
273 }
274
275 labels[i] = closest_centroid;
276 distances[i] = min_dist;
277 }
278
279 Ok((labels, distances))
280}
281
282#[allow(dead_code)]
299pub fn weighted_kmeans_plus_plus<F>(
300 data: ArrayView2<F>,
301 weights: ArrayView1<F>,
302 k: usize,
303 _random_seed: Option<u64>,
304) -> Result<Array2<F>>
305where
306 F: Float + FromPrimitive + Debug + std::iter::Sum,
307{
308 let n_samples = data.shape()[0];
309 let n_features = data.shape()[1];
310
311 if k == 0 || k > n_samples {
312 return Err(ClusteringError::InvalidInput(format!(
313 "Number of clusters ({}) must be between 1 and number of samples ({})",
314 k, n_samples
315 )));
316 }
317
318 if weights.len() != n_samples {
319 return Err(ClusteringError::InvalidInput(
320 "Weights array must have the same length as the number of samples".to_string(),
321 ));
322 }
323
324 let mut rng = scirs2_core::random::rng();
325
326 let mut centroids = Array2::zeros((k, n_features));
327
328 let total_weight: F = weights.iter().copied().sum();
330 let mut cumulative_weights = Array1::zeros(n_samples);
331 cumulative_weights[0] = weights[0] / total_weight;
332 for i in 1..n_samples {
333 cumulative_weights[i] = cumulative_weights[i - 1] + weights[i] / total_weight;
334 }
335
336 let rand_val = F::from(rng.random::<f64>()).unwrap();
337 let mut first_idx = 0;
338 for i in 0..n_samples {
339 if rand_val <= cumulative_weights[i] {
340 first_idx = i;
341 break;
342 }
343 }
344
345 for j in 0..n_features {
346 centroids[[0, j]] = data[[first_idx, j]];
347 }
348
349 if k == 1 {
350 return Ok(centroids);
351 }
352
353 for i in 1..k {
355 let mut weighted_distances = Array1::from_elem(n_samples, F::zero());
357
358 for sample_idx in 0..n_samples {
359 let sample = data.slice(s![sample_idx, ..]);
360 let mut min_dist_sq = F::infinity();
361
362 for centroid_idx in 0..i {
363 let centroid = centroids.slice(s![centroid_idx, ..]);
364 let dist = euclidean_distance(sample, centroid);
365 let dist_sq = dist * dist;
366
367 if dist_sq < min_dist_sq {
368 min_dist_sq = dist_sq;
369 }
370 }
371
372 weighted_distances[sample_idx] = weights[sample_idx] * min_dist_sq;
373 }
374
375 let sum_weighted_distances: F = weighted_distances.iter().copied().sum();
377 if sum_weighted_distances <= F::epsilon() {
378 let remaining_weight: F = weights.iter().copied().sum();
380 for sample_idx in 0..n_samples {
381 weighted_distances[sample_idx] = weights[sample_idx] / remaining_weight;
382 }
383 } else {
384 weighted_distances.mapv_inplace(|d| d / sum_weighted_distances);
385 }
386
387 let mut cum_weighted_distances = weighted_distances.clone();
389 for j in 1..n_samples {
390 cum_weighted_distances[j] = cum_weighted_distances[j] + cum_weighted_distances[j - 1];
391 }
392
393 let rand_val = F::from(rng.random::<f64>()).unwrap();
395 let mut next_idx = 0;
396
397 for j in 0..n_samples {
398 if rand_val <= cum_weighted_distances[j] {
399 next_idx = j;
400 break;
401 }
402 }
403
404 for j in 0..n_features {
406 centroids[[i, j]] = data[[next_idx, j]];
407 }
408 }
409
410 Ok(centroids)
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array1, Array2};
    use std::collections::HashSet;

    /// Six 2-D points in two well-separated groups: rows 0-2 near (1, 2) and rows 3-5
    /// near (4, 5).
    fn two_groups() -> Array2<f64> {
        Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap()
    }

    /// Four 2-D points: two near (1, 2) followed by two near (4, 5).
    fn two_pairs() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.2, 1.8, 4.0, 5.0, 4.2, 4.8]).unwrap()
    }

    /// Options for exactly one initialisation with a fixed seed.
    fn one_seeded_run() -> WeightedKMeansOptions<f64> {
        WeightedKMeansOptions {
            n_init: 1,
            random_seed: Some(42),
            ..Default::default()
        }
    }

    #[test]
    fn test_weighted_kmeans_simple() {
        let data = two_groups();
        let weights = Array1::from_elem(data.nrows(), 1.0);

        let (centroids, labels) =
            weighted_kmeans(data.view(), weights.view(), 2, Some(one_seeded_run())).unwrap();

        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.len(), 6);

        // Both clusters must be populated.
        let distinct: HashSet<usize> = labels.iter().copied().collect();
        assert_eq!(distinct.len(), 2);
    }

    #[test]
    fn test_weighted_kmeans_different_weights() {
        let data = two_groups();
        // The first group is ten times heavier than the second.
        let weights = Array1::from_vec(vec![10.0, 10.0, 10.0, 1.0, 1.0, 1.0]);

        let (centroids, labels) =
            weighted_kmeans(data.view(), weights.view(), 2, Some(one_seeded_run())).unwrap();

        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.len(), 6);

        // Weighted mean of the heavy group (rows 0-2, each with weight 10).
        let heavy_weight = 10.0 * 3.0;
        let expected_x = (1.0 + 1.2 + 0.8) * 10.0 / heavy_weight;
        let expected_y = (2.0 + 1.8 + 1.9) * 10.0 / heavy_weight;

        let heavy = centroids.row(labels[0]);
        assert_abs_diff_eq!(heavy[0], expected_x, epsilon = 0.2);
        assert_abs_diff_eq!(heavy[1], expected_y, epsilon = 0.2);
    }

    #[test]
    fn test_weighted_kmeans_plus_plus() {
        let data = two_groups();
        let weights = Array1::from_vec(vec![1.0, 1.0, 1.0, 10.0, 10.0, 10.0]);

        let centroids =
            weighted_kmeans_plus_plus(data.view(), weights.view(), 2, Some(42)).unwrap();

        assert_eq!(centroids.shape(), &[2, 2]);
        assert!(centroids.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_weighted_kmeans_zero_weights() {
        let data = two_pairs();
        // Every other sample carries no weight.
        let weights = Array1::from_vec(vec![1.0, 0.0, 1.0, 0.0]);

        let (centroids, labels) =
            weighted_kmeans(data.view(), weights.view(), 2, Some(one_seeded_run()))
                .expect("zero weights on some samples must be accepted");

        assert_eq!(centroids.shape(), &[2, 2]);
        assert_eq!(labels.len(), 4);
    }

    #[test]
    fn test_weighted_kmeans_negative_weights() {
        let data = two_pairs();
        let weights = Array1::from_vec(vec![1.0, -1.0, 1.0, 1.0]);

        assert!(weighted_kmeans(data.view(), weights.view(), 2, None).is_err());
    }
}