Skip to main content

torsh_cluster/initialization/
forgy.rs

//! Forgy initialization strategy
2
3use super::InitializationStrategy;
4use crate::error::{ClusterError, ClusterResult};
5use scirs2_core::random::Random;
6use torsh_tensor::Tensor;
7
/// Forgy centroid initialization.
///
/// Seeds a clustering run by choosing `k` distinct rows of the input data at
/// random and using them verbatim as the initial centroids. The strategy has no
/// configuration of its own; construct it as `Forgy` or `Forgy::default()`.
#[derive(Default, Debug)]
pub struct Forgy;
11
12impl InitializationStrategy for Forgy {
13    fn initialize(
14        &self,
15        data: &Tensor,
16        n_clusters: usize,
17        seed: Option<u64>,
18    ) -> ClusterResult<Tensor> {
19        let n_samples = data.shape().dims()[0];
20        let n_features = data.shape().dims()[1];
21
22        if n_clusters > n_samples {
23            return Err(ClusterError::InvalidClusters(n_clusters));
24        }
25
26        let mut rng = Random::seed(seed.unwrap_or_else(|| {
27            use std::time::{SystemTime, UNIX_EPOCH};
28            SystemTime::now()
29                .duration_since(UNIX_EPOCH)
30                .expect("system time should be after UNIX_EPOCH")
31                .as_secs()
32        }));
33
34        let data_vec = data.to_vec().map_err(ClusterError::TensorError)?;
35        let mut selected = std::collections::HashSet::new();
36        let mut centroids_data = Vec::with_capacity(n_clusters * n_features);
37
38        for _ in 0..n_clusters {
39            let mut idx = rng.gen_range(0..n_samples);
40            while selected.contains(&idx) {
41                idx = rng.gen_range(0..n_samples);
42            }
43            selected.insert(idx);
44
45            for j in 0..n_features {
46                centroids_data.push(data_vec[idx * n_features + j]);
47            }
48        }
49
50        Tensor::from_vec(centroids_data, &[n_clusters, n_features])
51            .map_err(ClusterError::TensorError)
52    }
53
54    fn name(&self) -> &str {
55        "Forgy"
56    }
57}